Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
86c243b4
Unverified
Commit
86c243b4
authored
Jun 23, 2020
by
Mufei Li
Committed by
GitHub
Jun 23, 2020
Browse files
Fix (#1684)
parent
85e660cb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
269 additions
and
150 deletions
+269
-150
python/dgl/batched_heterograph.py
python/dgl/batched_heterograph.py
+241
-150
tests/compute/test_batched_heterograph.py
tests/compute/test_batched_heterograph.py
+28
-0
No files found.
python/dgl/batched_heterograph.py
View file @
86c243b4
...
@@ -31,15 +31,30 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
...
@@ -31,15 +31,30 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
Parameters
Parameters
----------
----------
graph_list : iterable
gidx : HeteroGraphIndex
A collection of :class:`~dgl.DGLHeteroGraph` to be batched.
Graph index object.
node_attrs : None or dict
ntypes : list of str, pair of list of str
The node attributes to be batched. If ``None``, the resulted graph will not have
Node type list. ``ntypes[i]`` stores the name of node type i.
features. If ``dict``, it maps str to str or iterable. The keys represent names of
If a pair is given, the graph created is a uni-directional bipartite graph,
node types and the values represent the node features to be batched for the
and its SRC node types and DST node types are given as in the pair.
corresponding type. By default, we use all features for all types of nodes.
etypes : list of str
edge_attrs : None or dict
Edge type list. ``etypes[i]`` stores the name of edge type i.
Same as for the case of :attr:`node_attrs`.
node_frames : list of FrameRef, optional
Node feature storage. If None, empty frame is created.
Otherwise, ``node_frames[i]`` stores the node features
of node type i. (default: None)
edge_frames : list of FrameRef, optional
Edge feature storage. If None, empty frame is created.
Otherwise, ``edge_frames[i]`` stores the edge features
of edge type i. (default: None)
batch_size : int
Number of heterogeneous graphs in the batch.
batch_num_nodes : list of list
batch_num_nodes[i][j] gives the number of nodes of the i-th type
in the j-th graph.
batch_num_edges : list of list
batch_num_edges[i][j] gives the number of edges of the i-th type
in the j-th graph.
Examples
Examples
--------
--------
...
@@ -132,150 +147,25 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
...
@@ -132,150 +147,25 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
>>> g4.nodes['game'].data['h1']
>>> g4.nodes['game'].data['h1']
tensor([[1.]])
tensor([[1.]])
"""
"""
def
__init__
(
self
,
graph_list
,
node_attrs
,
edge_attrs
):
def
__init__
(
self
,
# Sanity check. Make sure all graphs have the same node/edge types, in the same order.
gidx
,
ref_canonical_etypes
=
graph_list
[
0
].
canonical_etypes
ntypes
,
ref_ntypes
=
graph_list
[
0
].
ntypes
etypes
,
ref_etypes
=
graph_list
[
0
].
etypes
node_frames
,
for
i
in
range
(
1
,
len
(
graph_list
)):
edge_frames
,
g_i
=
graph_list
[
i
]
batch_size
,
assert
g_i
.
ntypes
==
ref_ntypes
,
\
batch_num_nodes
,
'The node types of graph {:d} and {:d} should be the same.'
.
format
(
0
,
i
)
batch_num_edges
):
assert
g_i
.
canonical_etypes
==
ref_canonical_etypes
,
\
super
(
BatchedDGLHeteroGraph
,
self
).
__init__
(
gidx
=
gidx
,
'The canonical edge types of graph {:d} and {:d} should be the same.'
.
format
(
0
,
i
)
ntypes
=
ntypes
,
etypes
=
etypes
,
# Sanity check. Make sure all graphs have same node/edge features in terns of name, size
# and dtype if the number of nodes is nonzero.
ref_node_feats
=
dict
()
for
nty
in
ref_ntypes
:
for
i
,
graph
in
enumerate
(
graph_list
):
# No nodes, skip it
if
graph
.
number_of_nodes
(
nty
)
==
0
:
continue
# Use this for reference of feature names, shape and dtype
if
nty
not
in
ref_node_feats
:
ref_node_feats
[
nty
]
=
(
i
,
graph
.
node_attr_schemes
(
nty
))
continue
# Name check
assert
set
(
ref_node_feats
[
nty
][
1
].
keys
())
==
\
set
(
graph
.
node_attr_schemes
(
nty
).
keys
()),
\
'The node features of graph {:d} and {:d} for node type {} should be the '
\
'same.'
.
format
(
ref_node_feats
[
nty
][
0
],
i
,
nty
)
# Size and dtype check
for
nfeats
in
ref_node_feats
[
nty
][
1
].
keys
():
assert
ref_node_feats
[
nty
][
1
][
nfeats
]
==
\
graph
.
node_attr_schemes
(
nty
)[
nfeats
],
\
'For graph {:d} and {:d}, the size and dtype for feature {} of '
\
'{}-typed nodes should be the same.'
.
format
(
ref_node_feats
[
nty
][
0
],
i
,
nfeats
,
nty
)
ref_edge_feats
=
dict
()
for
ety
in
ref_canonical_etypes
:
for
i
,
graph
in
enumerate
(
graph_list
):
# No edges, skip it
if
graph
.
number_of_edges
(
ety
)
==
0
:
continue
# Use this for reference of feature names, shape and dtype
if
ety
not
in
ref_edge_feats
:
ref_edge_feats
[
ety
]
=
(
i
,
graph
.
edge_attr_schemes
(
ety
))
continue
# Name check
assert
set
(
ref_edge_feats
[
ety
][
1
].
keys
())
==
\
set
(
graph
.
edge_attr_schemes
(
ety
).
keys
()),
\
'The edge features of graph {:d} and {:d} for edge type {} should be the '
\
'same.'
.
format
(
ref_edge_feats
[
ety
][
0
],
i
,
ety
)
# Size and dtype check
for
efeats
in
ref_edge_feats
[
ety
][
1
].
keys
():
assert
ref_edge_feats
[
ety
][
1
][
efeats
]
==
\
graph
.
edge_attr_schemes
(
ety
)[
efeats
],
\
'For graph {:d} and {:d}, the size and dtype for feature {} of '
\
'{}-typed edges should be the same.'
.
format
(
ref_edge_feats
[
ety
][
0
],
i
,
efeats
,
ety
)
def
_init_attrs
(
types
,
attrs
,
mode
):
formatted_attrs
=
{
t
:
[]
for
t
in
types
}
if
is_all
(
attrs
):
for
typ
in
types
:
if
mode
==
'node'
:
# Handle the case where the nodes of a type have no features
formatted_attrs
[
typ
]
=
list
(
ref_node_feats
.
get
(
typ
,
(
None
,
dict
()))[
1
].
keys
())
elif
mode
==
'edge'
:
# Handle the case where the edges of a type have no features
formatted_attrs
[
typ
]
=
list
(
ref_edge_feats
.
get
(
typ
,
(
None
,
dict
()))[
1
].
keys
())
elif
isinstance
(
attrs
,
dict
):
for
typ
,
v
in
attrs
.
items
():
if
isinstance
(
v
,
str
):
formatted_attrs
[
typ
]
=
[
v
]
elif
isinstance
(
v
,
Iterable
):
formatted_attrs
[
typ
]
=
list
(
v
)
elif
v
is
not
None
:
raise
ValueError
(
'Expected {} attrs for type {} to be str '
'or iterable, got {}'
.
format
(
mode
,
typ
,
type
(
v
)))
elif
attrs
is
not
None
:
raise
ValueError
(
'Expected {} attrs to be of type None or dict,'
'got type {}'
.
format
(
mode
,
type
(
attrs
)))
return
formatted_attrs
node_attrs
=
_init_attrs
(
ref_ntypes
,
node_attrs
,
'node'
)
edge_attrs
=
_init_attrs
(
ref_canonical_etypes
,
edge_attrs
,
'edge'
)
node_frames
=
[]
for
tid
,
typ
in
enumerate
(
ref_ntypes
):
if
len
(
node_attrs
[
typ
])
==
0
:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
node_frames
.
append
(
None
)
else
:
# NOTE: following code will materialize the columns of the input graphs.
cols
=
{
key
:
F
.
cat
([
gr
.
_node_frames
[
tid
][
key
]
for
gr
in
graph_list
if
gr
.
number_of_nodes
(
typ
)
>
0
],
dim
=
0
)
for
key
in
node_attrs
[
typ
]}
node_frames
.
append
(
FrameRef
(
Frame
(
cols
)))
edge_frames
=
[]
for
tid
,
typ
in
enumerate
(
ref_canonical_etypes
):
if
len
(
edge_attrs
[
typ
])
==
0
:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
edge_frames
.
append
(
None
)
else
:
# NOTE: following code will materialize the columns of the input graphs.
cols
=
{
key
:
F
.
cat
([
gr
.
_edge_frames
[
tid
][
key
]
for
gr
in
graph_list
if
gr
.
number_of_edges
(
typ
)
>
0
],
dim
=
0
)
for
key
in
edge_attrs
[
typ
]}
edge_frames
.
append
(
FrameRef
(
Frame
(
cols
)))
# Create graph index for the batched graph
metagraph
=
graph_list
[
0
].
_graph
.
metagraph
batched_index
=
heterograph_index
.
disjoint_union
(
metagraph
,
[
g
.
_graph
for
g
in
graph_list
])
super
(
BatchedDGLHeteroGraph
,
self
).
__init__
(
gidx
=
batched_index
,
ntypes
=
ref_ntypes
,
etypes
=
ref_etypes
,
node_frames
=
node_frames
,
node_frames
=
node_frames
,
edge_frames
=
edge_frames
)
edge_frames
=
edge_frames
)
# extra members
# extra members
self
.
_batch_size
=
0
self
.
_batch_size
=
batch_size
# Store number of nodes/edge based on the id of node/edge types as we need
self
.
_batch_num_nodes
=
batch_num_nodes
# to handle both edge type and canonical edge type.
self
.
_batch_num_edges
=
batch_num_edges
self
.
_batch_num_nodes
=
[[]
for
_
in
range
(
len
(
ref_ntypes
))]
self
.
_batch_num_edges
=
[[]
for
_
in
range
(
len
(
ref_etypes
))]
for
grh
in
graph_list
:
if
isinstance
(
grh
,
BatchedDGLHeteroGraph
):
# Handle input graphs that are already batched
self
.
_batch_size
+=
grh
.
_batch_size
for
ntype_id
in
range
(
len
(
ref_ntypes
)):
self
.
_batch_num_nodes
[
ntype_id
].
extend
(
grh
.
_batch_num_nodes
[
ntype_id
])
for
etype_id
in
range
(
len
(
ref_etypes
)):
self
.
_batch_num_edges
[
etype_id
].
extend
(
grh
.
_batch_num_edges
[
etype_id
])
else
:
self
.
_batch_size
+=
1
for
ntype_id
in
range
(
len
(
ref_ntypes
)):
self
.
_batch_num_nodes
[
ntype_id
].
append
(
grh
.
_graph
.
number_of_nodes
(
ntype_id
))
for
etype_id
in
range
(
len
(
ref_etypes
)):
self
.
_batch_num_edges
[
etype_id
].
append
(
grh
.
_graph
.
number_of_edges
(
etype_id
))
@
property
@
property
def
batch_size
(
self
):
def
batch_size
(
self
):
...
@@ -356,6 +246,62 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
...
@@ -356,6 +246,62 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
"""
"""
return
self
.
_batch_num_edges
[
self
.
get_etype_id
(
etype
)]
return
self
.
_batch_num_edges
[
self
.
get_etype_id
(
etype
)]
def
to
(
self
,
ctx
,
**
kwargs
):
# pylint: disable=invalid-name
"""Move ndata, edata and graph structure to the targeted device context (cpu/gpu).
Parameters
----------
ctx : Framework-specific device context object
The context to move data to.
kwargs : Key-word arguments.
Key-word arguments fed to the framework copy function.
Returns
-------
g : BatchedDGLHeteroGraph
Moved BatchedDGLHeteroGraph of the targeted mode.
Examples
--------
The following example uses PyTorch backend.
>>> # Create the first graph and set features for nodes of type 'user'
>>> g1 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 0)]})
>>> g1.nodes['user'].data['h1'] = th.tensor([[0.], [1.]])
>>> # Create the second graph and set features for nodes of type 'user'
>>> g2 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0)]})
>>> g2.nodes['user'].data['h1'] = th.tensor([[0.]])
>>> # Batch the graphs
>>> bg = dgl.batch_hetero([g1, g2])
>>> # Move the graph topology and features to GPU
>>> bg1 = bg.to(torch.device('cuda:0'))
>>> print(bg1.device)
device(type='cuda', index=0)
>>> print(bg.device)
device(type='cpu')
"""
new_nframes
=
[]
for
nframe
in
self
.
_node_frames
:
new_feats
=
{
k
:
F
.
copy_to
(
feat
,
ctx
)
for
k
,
feat
in
nframe
.
items
()}
new_nframes
.
append
(
FrameRef
(
Frame
(
new_feats
)))
new_eframes
=
[]
for
eframe
in
self
.
_edge_frames
:
new_feats
=
{
k
:
F
.
copy_to
(
feat
,
ctx
)
for
k
,
feat
in
eframe
.
items
()}
new_eframes
.
append
(
FrameRef
(
Frame
(
new_feats
)))
# TODO(minjie): replace the following line with the commented one to enable GPU graph.
new_gidx
=
self
.
_graph
#new_gidx = self._graph.copy_to(utils.to_dgl_context(ctx))
return
BatchedDGLHeteroGraph
(
gidx
=
new_gidx
,
ntypes
=
self
.
ntypes
,
etypes
=
self
.
etypes
,
node_frames
=
new_nframes
,
edge_frames
=
new_eframes
,
batch_size
=
self
.
batch_size
,
batch_num_nodes
=
self
.
_batch_num_nodes
,
batch_num_edges
=
self
.
_batch_num_edges
)
def
unbatch_hetero
(
graph
):
def
unbatch_hetero
(
graph
):
"""Return the list of heterographs in this batch.
"""Return the list of heterographs in this batch.
...
@@ -438,4 +384,149 @@ def batch_hetero(graph_list, node_attrs=ALL, edge_attrs=ALL):
...
@@ -438,4 +384,149 @@ def batch_hetero(graph_list, node_attrs=ALL, edge_attrs=ALL):
BatchedDGLHeteroGraph
BatchedDGLHeteroGraph
unbatch_hetero
unbatch_hetero
"""
"""
return
BatchedDGLHeteroGraph
(
graph_list
,
node_attrs
,
edge_attrs
)
# Sanity check. Make sure all graphs have the same node/edge types, in the same order.
ref_canonical_etypes
=
graph_list
[
0
].
canonical_etypes
ref_ntypes
=
graph_list
[
0
].
ntypes
ref_etypes
=
graph_list
[
0
].
etypes
for
i
in
range
(
1
,
len
(
graph_list
)):
g_i
=
graph_list
[
i
]
assert
g_i
.
ntypes
==
ref_ntypes
,
\
'The node types of graph {:d} and {:d} should be the same.'
.
format
(
0
,
i
)
assert
g_i
.
canonical_etypes
==
ref_canonical_etypes
,
\
'The canonical edge types of graph {:d} and {:d} should be the same.'
.
format
(
0
,
i
)
# Sanity check. Make sure all graphs have same node/edge features in terns of name, size
# and dtype if the number of nodes is nonzero.
ref_node_feats
=
dict
()
for
nty
in
ref_ntypes
:
for
i
,
graph
in
enumerate
(
graph_list
):
# No nodes, skip it
if
graph
.
number_of_nodes
(
nty
)
==
0
:
continue
# Use this for reference of feature names, shape and dtype
if
nty
not
in
ref_node_feats
:
ref_node_feats
[
nty
]
=
(
i
,
graph
.
node_attr_schemes
(
nty
))
continue
# Name check
assert
set
(
ref_node_feats
[
nty
][
1
].
keys
())
==
\
set
(
graph
.
node_attr_schemes
(
nty
).
keys
()),
\
'The node features of graph {:d} and {:d} for node type {} should be the '
\
'same.'
.
format
(
ref_node_feats
[
nty
][
0
],
i
,
nty
)
# Size and dtype check
for
nfeats
in
ref_node_feats
[
nty
][
1
].
keys
():
assert
ref_node_feats
[
nty
][
1
][
nfeats
]
==
\
graph
.
node_attr_schemes
(
nty
)[
nfeats
],
\
'For graph {:d} and {:d}, the size and dtype for feature {} of '
\
'{}-typed nodes should be the same.'
.
format
(
ref_node_feats
[
nty
][
0
],
i
,
nfeats
,
nty
)
ref_edge_feats
=
dict
()
for
ety
in
ref_canonical_etypes
:
for
i
,
graph
in
enumerate
(
graph_list
):
# No edges, skip it
if
graph
.
number_of_edges
(
ety
)
==
0
:
continue
# Use this for reference of feature names, shape and dtype
if
ety
not
in
ref_edge_feats
:
ref_edge_feats
[
ety
]
=
(
i
,
graph
.
edge_attr_schemes
(
ety
))
continue
# Name check
assert
set
(
ref_edge_feats
[
ety
][
1
].
keys
())
==
\
set
(
graph
.
edge_attr_schemes
(
ety
).
keys
()),
\
'The edge features of graph {:d} and {:d} for edge type {} should be the '
\
'same.'
.
format
(
ref_edge_feats
[
ety
][
0
],
i
,
ety
)
# Size and dtype check
for
efeats
in
ref_edge_feats
[
ety
][
1
].
keys
():
assert
ref_edge_feats
[
ety
][
1
][
efeats
]
==
\
graph
.
edge_attr_schemes
(
ety
)[
efeats
],
\
'For graph {:d} and {:d}, the size and dtype for feature {} of '
\
'{}-typed edges should be the same.'
.
format
(
ref_edge_feats
[
ety
][
0
],
i
,
efeats
,
ety
)
def
_init_attrs
(
types
,
attrs
,
mode
):
formatted_attrs
=
{
t
:
[]
for
t
in
types
}
if
is_all
(
attrs
):
for
typ
in
types
:
if
mode
==
'node'
:
# Handle the case where the nodes of a type have no features
formatted_attrs
[
typ
]
=
list
(
ref_node_feats
.
get
(
typ
,
(
None
,
dict
()))[
1
].
keys
())
elif
mode
==
'edge'
:
# Handle the case where the edges of a type have no features
formatted_attrs
[
typ
]
=
list
(
ref_edge_feats
.
get
(
typ
,
(
None
,
dict
()))[
1
].
keys
())
elif
isinstance
(
attrs
,
dict
):
for
typ
,
v
in
attrs
.
items
():
if
isinstance
(
v
,
str
):
formatted_attrs
[
typ
]
=
[
v
]
elif
isinstance
(
v
,
Iterable
):
formatted_attrs
[
typ
]
=
list
(
v
)
elif
v
is
not
None
:
raise
ValueError
(
'Expected {} attrs for type {} to be str '
'or iterable, got {}'
.
format
(
mode
,
typ
,
type
(
v
)))
elif
attrs
is
not
None
:
raise
ValueError
(
'Expected {} attrs to be of type None or dict,'
'got type {}'
.
format
(
mode
,
type
(
attrs
)))
return
formatted_attrs
node_attrs
=
_init_attrs
(
ref_ntypes
,
node_attrs
,
'node'
)
edge_attrs
=
_init_attrs
(
ref_canonical_etypes
,
edge_attrs
,
'edge'
)
node_frames
=
[]
for
tid
,
typ
in
enumerate
(
ref_ntypes
):
if
len
(
node_attrs
[
typ
])
==
0
:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
node_frames
.
append
(
None
)
else
:
# NOTE: following code will materialize the columns of the input graphs.
cols
=
{
key
:
F
.
cat
([
gr
.
_node_frames
[
tid
][
key
]
for
gr
in
graph_list
if
gr
.
number_of_nodes
(
typ
)
>
0
],
dim
=
0
)
for
key
in
node_attrs
[
typ
]}
node_frames
.
append
(
FrameRef
(
Frame
(
cols
)))
edge_frames
=
[]
for
tid
,
typ
in
enumerate
(
ref_canonical_etypes
):
if
len
(
edge_attrs
[
typ
])
==
0
:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
edge_frames
.
append
(
None
)
else
:
# NOTE: following code will materialize the columns of the input graphs.
cols
=
{
key
:
F
.
cat
([
gr
.
_edge_frames
[
tid
][
key
]
for
gr
in
graph_list
if
gr
.
number_of_edges
(
typ
)
>
0
],
dim
=
0
)
for
key
in
edge_attrs
[
typ
]}
edge_frames
.
append
(
FrameRef
(
Frame
(
cols
)))
# Create graph index for the batched graph
metagraph
=
graph_list
[
0
].
_graph
.
metagraph
batched_index
=
heterograph_index
.
disjoint_union
(
metagraph
,
[
g
.
_graph
for
g
in
graph_list
])
batch_size
=
0
# Store number of nodes/edge based on the id of node/edge types as we need
# to handle both edge type and canonical edge type.
batch_num_nodes
=
[[]
for
_
in
range
(
len
(
ref_ntypes
))]
batch_num_edges
=
[[]
for
_
in
range
(
len
(
ref_etypes
))]
for
grh
in
graph_list
:
if
isinstance
(
grh
,
BatchedDGLHeteroGraph
):
# Handle input graphs that are already batched
batch_size
+=
grh
.
_batch_size
for
ntype_id
in
range
(
len
(
ref_ntypes
)):
batch_num_nodes
[
ntype_id
].
extend
(
grh
.
_batch_num_nodes
[
ntype_id
])
for
etype_id
in
range
(
len
(
ref_etypes
)):
batch_num_edges
[
etype_id
].
extend
(
grh
.
_batch_num_edges
[
etype_id
])
else
:
batch_size
+=
1
for
ntype_id
in
range
(
len
(
ref_ntypes
)):
batch_num_nodes
[
ntype_id
].
append
(
grh
.
_graph
.
number_of_nodes
(
ntype_id
))
for
etype_id
in
range
(
len
(
ref_etypes
)):
batch_num_edges
[
etype_id
].
append
(
grh
.
_graph
.
number_of_edges
(
etype_id
))
return
BatchedDGLHeteroGraph
(
gidx
=
batched_index
,
ntypes
=
ref_ntypes
,
etypes
=
ref_etypes
,
node_frames
=
node_frames
,
edge_frames
=
edge_frames
,
batch_size
=
batch_size
,
batch_num_nodes
=
batch_num_nodes
,
batch_num_edges
=
batch_num_edges
)
tests/compute/test_batched_heterograph.py
View file @
86c243b4
import
dgl
import
dgl
import
backend
as
F
import
backend
as
F
import
unittest
from
dgl.base
import
ALL
from
dgl.base
import
ALL
from
utils
import
parametrize_dtype
from
utils
import
parametrize_dtype
...
@@ -268,8 +269,35 @@ def test_batching_with_zero_nodes_edges(index_dtype):
...
@@ -268,8 +269,35 @@ def test_batching_with_zero_nodes_edges(index_dtype):
g2
.
nodes
[
'u'
].
data
[
'x'
]
=
F
.
tensor
([
1
])
g2
.
nodes
[
'u'
].
data
[
'x'
]
=
F
.
tensor
([
1
])
dgl
.
batch_hetero
([
g1
,
g2
])
dgl
.
batch_hetero
([
g1
,
g2
])
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'cpu'
,
reason
=
"Need gpu for this test"
)
@
parametrize_dtype
def
test_to_device
(
index_dtype
):
g1
=
dgl
.
heterograph
({
(
'user'
,
'plays'
,
'game'
):
[(
0
,
0
),
(
1
,
1
)]
},
index_dtype
=
index_dtype
)
g1
.
nodes
[
'user'
].
data
[
'h1'
]
=
F
.
copy_to
(
F
.
tensor
([[
0.
],
[
1.
]]),
F
.
cpu
())
g1
.
nodes
[
'user'
].
data
[
'h2'
]
=
F
.
copy_to
(
F
.
tensor
([[
3.
],
[
4.
]]),
F
.
cpu
())
g1
.
edges
[
'plays'
].
data
[
'h1'
]
=
F
.
copy_to
(
F
.
tensor
([[
2.
],
[
3.
]]),
F
.
cpu
())
g2
=
dgl
.
heterograph
({
(
'user'
,
'plays'
,
'game'
):
[(
0
,
0
),
(
1
,
0
)]
},
index_dtype
=
index_dtype
)
g2
.
nodes
[
'user'
].
data
[
'h1'
]
=
F
.
copy_to
(
F
.
tensor
([[
1.
],
[
2.
]]),
F
.
cpu
())
g2
.
nodes
[
'user'
].
data
[
'h2'
]
=
F
.
copy_to
(
F
.
tensor
([[
4.
],
[
5.
]]),
F
.
cpu
())
g2
.
edges
[
'plays'
].
data
[
'h1'
]
=
F
.
copy_to
(
F
.
tensor
([[
0.
],
[
1.
]]),
F
.
cpu
())
bg
=
dgl
.
batch_hetero
([
g1
,
g2
])
if
F
.
is_cuda_available
():
bg1
=
bg
.
to
(
F
.
cuda
())
assert
bg1
is
not
None
assert
bg
.
batch_size
==
bg1
.
batch_size
assert
bg
.
batch_num_nodes
(
'user'
)
==
bg1
.
batch_num_nodes
(
'user'
)
assert
bg
.
batch_num_edges
(
'plays'
)
==
bg1
.
batch_num_edges
(
'plays'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_batching_hetero_topology
()
test_batching_hetero_topology
()
test_batching_hetero_and_batched_hetero_topology
()
test_batching_hetero_and_batched_hetero_topology
()
test_batched_features
()
test_batched_features
()
test_batching_with_zero_nodes_edges
()
test_batching_with_zero_nodes_edges
()
# test_to_device()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment