Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0f40c6e4
Unverified
Commit
0f40c6e4
authored
Mar 20, 2020
by
Mufei Li
Committed by
GitHub
Mar 20, 2020
Browse files
[Hetero] Replace card with num_nodes
Co-authored-by:
Minjie Wang
<
wmjlyjemaine@gmail.com
>
parent
1b9bc16b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
84 additions
and
52 deletions
+84
-52
examples/mxnet/gcmc/data.py
examples/mxnet/gcmc/data.py
+2
-2
examples/pytorch/gcmc/data.py
examples/pytorch/gcmc/data.py
+2
-2
python/dgl/convert.py
python/dgl/convert.py
+39
-25
python/dgl/generators.py
python/dgl/generators.py
+1
-1
python/dgl/heterograph.py
python/dgl/heterograph.py
+1
-1
python/dgl/sampling/pinsage.py
python/dgl/sampling/pinsage.py
+1
-1
python/dgl/transform.py
python/dgl/transform.py
+2
-2
tests/compute/test_hetero_basics.py
tests/compute/test_hetero_basics.py
+21
-3
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+5
-5
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+8
-8
tests/compute/test_transform.py
tests/compute/test_transform.py
+2
-2
No files found.
examples/mxnet/gcmc/data.py
View file @
0f40c6e4
...
@@ -246,9 +246,9 @@ class MovieLens(object):
...
@@ -246,9 +246,9 @@ class MovieLens(object):
rrow
=
rating_row
[
ridx
]
rrow
=
rating_row
[
ridx
]
rcol
=
rating_col
[
ridx
]
rcol
=
rating_col
[
ridx
]
bg
=
dgl
.
bipartite
((
rrow
,
rcol
),
'user'
,
str
(
rating
),
'movie'
,
bg
=
dgl
.
bipartite
((
rrow
,
rcol
),
'user'
,
str
(
rating
),
'movie'
,
card
=
(
self
.
_num_user
,
self
.
_num_movie
))
num_nodes
=
(
self
.
_num_user
,
self
.
_num_movie
))
rev_bg
=
dgl
.
bipartite
((
rcol
,
rrow
),
'movie'
,
'rev-%s'
%
str
(
rating
),
'user'
,
rev_bg
=
dgl
.
bipartite
((
rcol
,
rrow
),
'movie'
,
'rev-%s'
%
str
(
rating
),
'user'
,
card
=
(
self
.
_num_movie
,
self
.
_num_user
))
num_nodes
=
(
self
.
_num_movie
,
self
.
_num_user
))
rating_graphs
.
append
(
bg
)
rating_graphs
.
append
(
bg
)
rating_graphs
.
append
(
rev_bg
)
rating_graphs
.
append
(
rev_bg
)
graph
=
dgl
.
hetero_from_relations
(
rating_graphs
)
graph
=
dgl
.
hetero_from_relations
(
rating_graphs
)
...
...
examples/pytorch/gcmc/data.py
View file @
0f40c6e4
...
@@ -246,9 +246,9 @@ class MovieLens(object):
...
@@ -246,9 +246,9 @@ class MovieLens(object):
rrow
=
rating_row
[
ridx
]
rrow
=
rating_row
[
ridx
]
rcol
=
rating_col
[
ridx
]
rcol
=
rating_col
[
ridx
]
bg
=
dgl
.
bipartite
((
rrow
,
rcol
),
'user'
,
str
(
rating
),
'movie'
,
bg
=
dgl
.
bipartite
((
rrow
,
rcol
),
'user'
,
str
(
rating
),
'movie'
,
card
=
(
self
.
_num_user
,
self
.
_num_movie
))
num_nodes
=
(
self
.
_num_user
,
self
.
_num_movie
))
rev_bg
=
dgl
.
bipartite
((
rcol
,
rrow
),
'movie'
,
'rev-%s'
%
str
(
rating
),
'user'
,
rev_bg
=
dgl
.
bipartite
((
rcol
,
rrow
),
'movie'
,
'rev-%s'
%
str
(
rating
),
'user'
,
card
=
(
self
.
_num_movie
,
self
.
_num_user
))
num_nodes
=
(
self
.
_num_movie
,
self
.
_num_user
))
rating_graphs
.
append
(
bg
)
rating_graphs
.
append
(
bg
)
rating_graphs
.
append
(
rev_bg
)
rating_graphs
.
append
(
rev_bg
)
graph
=
dgl
.
hetero_from_relations
(
rating_graphs
)
graph
=
dgl
.
hetero_from_relations
(
rating_graphs
)
...
...
python/dgl/convert.py
View file @
0f40c6e4
...
@@ -9,7 +9,7 @@ from . import heterograph_index
...
@@ -9,7 +9,7 @@ from . import heterograph_index
from
.heterograph
import
DGLHeteroGraph
,
combine_frames
from
.heterograph
import
DGLHeteroGraph
,
combine_frames
from
.
import
graph_index
from
.
import
graph_index
from
.
import
utils
from
.
import
utils
from
.base
import
NTYPE
,
ETYPE
,
NID
,
EID
,
DGLError
from
.base
import
NTYPE
,
ETYPE
,
NID
,
EID
,
DGLError
,
dgl_warning
__all__
=
[
__all__
=
[
'graph'
,
'graph'
,
...
@@ -21,8 +21,8 @@ __all__ = [
...
@@ -21,8 +21,8 @@ __all__ = [
'to_networkx'
,
'to_networkx'
,
]
]
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
card
=
None
,
validate
=
True
,
restrict_format
=
'any'
,
def
graph
(
data
,
ntype
=
'_N'
,
etype
=
'_E'
,
num_nodes
=
None
,
card
=
None
,
validate
=
True
,
**
kwargs
):
restrict_format
=
'any'
,
**
kwargs
):
"""Create a graph with one type of nodes and edges.
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...
@@ -43,9 +43,12 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
...
@@ -43,9 +43,12 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
Node type name. (Default: _N)
Node type name. (Default: _N)
etype : str, optional
etype : str, optional
Edge type name. (Default: _E)
Edge type name. (Default: _E)
card
: int, optional
num_nodes
: int, optional
Cardinality (n
umber of nodes in the graph
)
. If None, infer from input data, i.e.
N
umber of nodes in the graph. If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None)
the largest node ID plus 1. (Default: None)
card : int, optional
Deprecated (see :attr:`num_nodes`). Cardinality (number of nodes in the graph).
If None, infer from input data, i.e. the largest node ID plus 1. (Default: None)
validate : bool, optional
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
If True, check if node ids are within cardinality, the check process may take
some time. (Default: True)
some time. (Default: True)
...
@@ -109,18 +112,22 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
...
@@ -109,18 +112,22 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
>>> g.canonical_etypes
>>> g.canonical_etypes
[('user', 'follows', 'user')]
[('user', 'follows', 'user')]
Check if node ids are within
cardinality
Check if node ids are within
num_nodes specified
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]),
card
=2, validate=True)
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]),
num_nodes
=2, validate=True)
...
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]),
card
=3, validate=True)
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]),
num_nodes
=3, validate=True)
Graph(num_nodes=3, num_edges=3,
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
ndata_schemes={}
edata_schemes={})
edata_schemes={})
"""
"""
if
card
is
not
None
:
if
card
is
not
None
:
urange
,
vrange
=
card
,
card
dgl_warning
(
"card will be deprecated, please use num_nodes='{}' instead."
)
num_nodes
=
card
if
num_nodes
is
not
None
:
urange
,
vrange
=
num_nodes
,
num_nodes
else
:
else
:
urange
,
vrange
=
None
,
None
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
if
isinstance
(
data
,
tuple
):
...
@@ -141,8 +148,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
...
@@ -141,8 +148,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
else
:
else
:
raise
DGLError
(
'Unsupported graph data type:'
,
type
(
data
))
raise
DGLError
(
'Unsupported graph data type:'
,
type
(
data
))
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
card
=
None
,
validate
=
Tru
e
,
def
bipartite
(
data
,
utype
=
'_U'
,
etype
=
'_E'
,
vtype
=
'_V'
,
num_nodes
=
None
,
card
=
Non
e
,
restrict_format
=
'any'
,
**
kwargs
):
validate
=
True
,
restrict_format
=
'any'
,
**
kwargs
):
"""Create a bipartite graph.
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
The result graph is directed and edges must be from ``utype`` nodes
...
@@ -168,9 +175,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
...
@@ -168,9 +175,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
Edge type name. (Default: _E)
Edge type name. (Default: _E)
vtype : str, optional
vtype : str, optional
Destination node type name. (Default: _V)
Destination node type name. (Default: _V)
card : pair of int, optional
num_nodes : 2-tuple of int, optional
Cardinality (number of nodes in the source and destination group). If None,
Number of nodes in the source and destination group. If None, infer from input data,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
i.e. the largest node ID plus 1 for each type. (Default: None)
card : 2-tuple of int, optional
Deprecated (see :attr:`num_nodes`). Cardinality (number of nodes in the source and
destination group). If None, infer from input data, i.e. the largest node ID plus 1
for each type. (Default: None)
validate : bool, optional
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
If True, check if node ids are within cardinality, the check process may take
some time. (Default: True)
some time. (Default: True)
...
@@ -246,12 +257,12 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
...
@@ -246,12 +257,12 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
>>> g.edges()
>>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within
cardinality
Check if node ids are within
num_nodes specified
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]),
card
=(2, 4), validate=True)
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]),
num_nodes
=(2, 4), validate=True)
...
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]),
card
=(3, 4), validate=True)
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]),
num_nodes
=(3, 4), validate=True)
>>> g
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
num_edges={('_U', '_E', '_V'): 3},
...
@@ -260,7 +271,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
...
@@ -260,7 +271,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
if
utype
==
vtype
:
if
utype
==
vtype
:
raise
DGLError
(
'utype should not be equal to vtype. Use ``dgl.graph`` instead.'
)
raise
DGLError
(
'utype should not be equal to vtype. Use ``dgl.graph`` instead.'
)
if
card
is
not
None
:
if
card
is
not
None
:
urange
,
vrange
=
card
dgl_warning
(
"card will be deprecated, please use num_nodes='{}' instead."
)
num_nodes
=
card
if
num_nodes
is
not
None
:
urange
,
vrange
=
num_nodes
else
:
else
:
urange
,
vrange
=
None
,
None
urange
,
vrange
=
None
,
None
if
isinstance
(
data
,
tuple
):
if
isinstance
(
data
,
tuple
):
...
@@ -321,9 +335,9 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
...
@@ -321,9 +335,9 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
the relation graphs.
the relation graphs.
>>> # A graph with 4 nodes of type 'user'
>>> # A graph with 4 nodes of type 'user'
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows',
card
=4)
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows',
num_nodes
=4)
>>> # A bipartite graph with 4 nodes of src type ('user') and 2 nodes of dst type ('game')
>>> # A bipartite graph with 4 nodes of src type ('user') and 2 nodes of dst type ('game')
>>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game',
card
=(4, 2))
>>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game',
num_nodes
=(4, 2))
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
>>> print(g)
>>> print(g)
...
@@ -468,11 +482,11 @@ def heterograph(data_dict, num_nodes_dict=None):
...
@@ -468,11 +482,11 @@ def heterograph(data_dict, num_nodes_dict=None):
elif
srctype
==
dsttype
:
elif
srctype
==
dsttype
:
rel_graphs
.
append
(
graph
(
rel_graphs
.
append
(
graph
(
data
,
srctype
,
etype
,
data
,
srctype
,
etype
,
card
=
num_nodes_dict
[
srctype
],
validate
=
False
))
num_nodes
=
num_nodes_dict
[
srctype
],
validate
=
False
))
else
:
else
:
rel_graphs
.
append
(
bipartite
(
rel_graphs
.
append
(
bipartite
(
data
,
srctype
,
etype
,
dsttype
,
data
,
srctype
,
etype
,
dsttype
,
card
=
(
num_nodes_dict
[
srctype
],
num_nodes_dict
[
dsttype
]),
validate
=
False
))
num_nodes
=
(
num_nodes_dict
[
srctype
],
num_nodes_dict
[
dsttype
]),
validate
=
False
))
return
hetero_from_relations
(
rel_graphs
,
num_nodes_dict
)
return
hetero_from_relations
(
rel_graphs
,
num_nodes_dict
)
...
@@ -625,11 +639,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
...
@@ -625,11 +639,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
if
stid
==
dtid
:
if
stid
==
dtid
:
rel_graph
=
graph
(
rel_graph
=
graph
(
(
src_of_etype
,
dst_of_etype
),
ntypes
[
stid
],
etypes
[
etid
],
(
src_of_etype
,
dst_of_etype
),
ntypes
[
stid
],
etypes
[
etid
],
card
=
ntype_count
[
stid
],
validate
=
False
)
num_nodes
=
ntype_count
[
stid
],
validate
=
False
)
else
:
else
:
rel_graph
=
bipartite
(
rel_graph
=
bipartite
(
(
src_of_etype
,
dst_of_etype
),
ntypes
[
stid
],
etypes
[
etid
],
ntypes
[
dtid
],
(
src_of_etype
,
dst_of_etype
),
ntypes
[
stid
],
etypes
[
etid
],
ntypes
[
dtid
],
card
=
(
ntype_count
[
stid
],
ntype_count
[
dtid
]),
validate
=
False
)
num_nodes
=
(
ntype_count
[
stid
],
ntype_count
[
dtid
]),
validate
=
False
)
rel_graphs
.
append
(
rel_graph
)
rel_graphs
.
append
(
rel_graph
)
hg
=
hetero_from_relations
(
hg
=
hetero_from_relations
(
...
@@ -717,7 +731,7 @@ def to_homo(G):
...
@@ -717,7 +731,7 @@ def to_homo(G):
etype_ids
.
append
(
F
.
full_1d
(
num_edges
,
etype_id
,
F
.
int64
,
F
.
cpu
()))
etype_ids
.
append
(
F
.
full_1d
(
num_edges
,
etype_id
,
F
.
int64
,
F
.
cpu
()))
eids
.
append
(
F
.
arange
(
0
,
num_edges
))
eids
.
append
(
F
.
arange
(
0
,
num_edges
))
retg
=
graph
((
F
.
cat
(
srcs
,
0
),
F
.
cat
(
dsts
,
0
)),
card
=
total_num_nodes
,
validate
=
False
)
retg
=
graph
((
F
.
cat
(
srcs
,
0
),
F
.
cat
(
dsts
,
0
)),
num_nodes
=
total_num_nodes
,
validate
=
False
)
retg
.
ndata
[
NTYPE
]
=
F
.
cat
(
ntype_ids
,
0
)
retg
.
ndata
[
NTYPE
]
=
F
.
cat
(
ntype_ids
,
0
)
retg
.
ndata
[
NID
]
=
F
.
cat
(
nids
,
0
)
retg
.
ndata
[
NID
]
=
F
.
cat
(
nids
,
0
)
retg
.
edata
[
ETYPE
]
=
F
.
cat
(
etype_ids
,
0
)
retg
.
edata
[
ETYPE
]
=
F
.
cat
(
etype_ids
,
0
)
...
...
python/dgl/generators.py
View file @
0f40c6e4
...
@@ -31,6 +31,6 @@ def rand_graph(num_nodes, num_edges, restrict_format='any'):
...
@@ -31,6 +31,6 @@ def rand_graph(num_nodes, num_edges, restrict_format='any'):
rows
=
F
.
astype
(
eids
/
num_nodes
,
F
.
dtype
(
eids
))
rows
=
F
.
astype
(
eids
/
num_nodes
,
F
.
dtype
(
eids
))
cols
=
F
.
astype
(
eids
%
num_nodes
,
F
.
dtype
(
eids
))
cols
=
F
.
astype
(
eids
%
num_nodes
,
F
.
dtype
(
eids
))
g
=
convert
.
graph
((
rows
,
cols
),
g
=
convert
.
graph
((
rows
,
cols
),
card
=
num_nodes
,
validate
=
False
,
num_nodes
=
num_nodes
,
validate
=
False
,
restrict_format
=
restrict_format
)
restrict_format
=
restrict_format
)
return
g
return
g
python/dgl/heterograph.py
View file @
0f40c6e4
...
@@ -3831,7 +3831,7 @@ class DGLHeteroGraph(object):
...
@@ -3831,7 +3831,7 @@ class DGLHeteroGraph(object):
>>> import torch
>>> import torch
>>> import dgl
>>> import dgl
>>> import dgl.function as fn
>>> import dgl.function as fn
>>> g = dgl.graph([], 'user', 'follows',
card
=4)
>>> g = dgl.graph([], 'user', 'follows',
num_nodes
=4)
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
tensor([1, 2])
tensor([1, 2])
...
...
python/dgl/sampling/pinsage.py
View file @
0f40c6e4
...
@@ -109,7 +109,7 @@ class RandomWalkNeighborSampler(object):
...
@@ -109,7 +109,7 @@ class RandomWalkNeighborSampler(object):
# count the number of visits and pick the K-most frequent neighbors for each node
# count the number of visits and pick the K-most frequent neighbors for each node
neighbor_graph
=
convert
.
graph
(
neighbor_graph
=
convert
.
graph
(
(
src
,
dst
),
card
=
self
.
G
.
number_of_nodes
(
self
.
ntype
),
ntype
=
self
.
ntype
)
(
src
,
dst
),
num_nodes
=
self
.
G
.
number_of_nodes
(
self
.
ntype
),
ntype
=
self
.
ntype
)
neighbor_graph
=
transform
.
to_simple
(
neighbor_graph
,
return_counts
=
self
.
weight_column
)
neighbor_graph
=
transform
.
to_simple
(
neighbor_graph
,
return_counts
=
self
.
weight_column
)
counts
=
neighbor_graph
.
edata
[
self
.
weight_column
]
counts
=
neighbor_graph
.
edata
[
self
.
weight_column
]
neighbor_graph
=
select_topk
(
neighbor_graph
,
self
.
num_neighbors
,
self
.
weight_column
)
neighbor_graph
=
select_topk
(
neighbor_graph
,
self
.
num_neighbors
,
self
.
weight_column
)
...
...
python/dgl/transform.py
View file @
0f40c6e4
...
@@ -654,7 +654,7 @@ def compact_graphs(graphs, always_preserve=None):
...
@@ -654,7 +654,7 @@ def compact_graphs(graphs, always_preserve=None):
The following code constructs a bipartite graph with 20 users and 10 games, but
The following code constructs a bipartite graph with 20 users and 10 games, but
only user #1 and #3, as well as game #3 and #5, have connections:
only user #1 and #3, as well as game #3 and #5, have connections:
>>> g = dgl.bipartite([(1, 3), (3, 5)], 'user', 'plays', 'game',
card
=(20, 10))
>>> g = dgl.bipartite([(1, 3), (3, 5)], 'user', 'plays', 'game',
num_nodes
=(20, 10))
The following would compact the graph above to another bipartite graph with only
The following would compact the graph above to another bipartite graph with only
two users and two games.
two users and two games.
...
@@ -676,7 +676,7 @@ def compact_graphs(graphs, always_preserve=None):
...
@@ -676,7 +676,7 @@ def compact_graphs(graphs, always_preserve=None):
of the given graphs are removed. So if we compact ``g`` and the following ``g2``
of the given graphs are removed. So if we compact ``g`` and the following ``g2``
graphs together:
graphs together:
>>> g2 = dgl.bipartite([(1, 6), (6, 8)], 'user', 'plays', 'game',
card
=(20, 10))
>>> g2 = dgl.bipartite([(1, 6), (6, 8)], 'user', 'plays', 'game',
num_nodes
=(20, 10))
>>> (new_g, new_g2), induced_nodes = dgl.compact_graphs([g, g2])
>>> (new_g, new_g2), induced_nodes = dgl.compact_graphs([g, g2])
>>> induced_nodes
>>> induced_nodes
{'user': tensor([1, 3, 6]), 'game': tensor([3, 5, 6, 8])}
{'user': tensor([1, 3, 6]), 'game': tensor([3, 5, 6, 8])}
...
...
tests/compute/test_hetero_basics.py
View file @
0f40c6e4
...
@@ -60,6 +60,23 @@ def generate_graph(grad=False):
...
@@ -60,6 +60,23 @@ def generate_graph(grad=False):
g
.
set_e_initializer
(
dgl
.
init
.
zero_initializer
)
g
.
set_e_initializer
(
dgl
.
init
.
zero_initializer
)
return
g
return
g
def
test_isolated_nodes
():
g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
num_nodes
=
5
)
assert
g
.
number_of_nodes
()
==
5
# Test backward compatibility
g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
card
=
5
)
assert
g
.
number_of_nodes
()
==
5
g
=
dgl
.
bipartite
([(
0
,
2
),
(
0
,
3
),
(
1
,
2
)],
'user'
,
'plays'
,
'game'
,
num_nodes
=
(
5
,
7
))
assert
g
.
number_of_nodes
(
'user'
)
==
5
assert
g
.
number_of_nodes
(
'game'
)
==
7
# Test backward compatibility
g
=
dgl
.
bipartite
([(
0
,
2
),
(
0
,
3
),
(
1
,
2
)],
'user'
,
'plays'
,
'game'
,
card
=
(
5
,
7
))
assert
g
.
number_of_nodes
(
'user'
)
==
5
assert
g
.
number_of_nodes
(
'game'
)
==
7
def
test_batch_setter_getter
():
def
test_batch_setter_getter
():
def
_pfc
(
x
):
def
_pfc
(
x
):
return
list
(
F
.
zerocopy_to_numpy
(
x
)[:,
0
])
return
list
(
F
.
zerocopy_to_numpy
(
x
)[:,
0
])
...
@@ -452,8 +469,8 @@ def test_update_all_0deg():
...
@@ -452,8 +469,8 @@ def test_update_all_0deg():
assert
F
.
allclose
(
new_repr
[
1
:],
2
*
(
2
+
F
.
zeros
((
4
,
5
))))
assert
F
.
allclose
(
new_repr
[
1
:],
2
*
(
2
+
F
.
zeros
((
4
,
5
))))
assert
F
.
allclose
(
new_repr
[
0
],
2
*
F
.
sum
(
old_repr
,
0
))
assert
F
.
allclose
(
new_repr
[
0
],
2
*
F
.
sum
(
old_repr
,
0
))
# test#2:
graph with no edge
# test#2:
g
=
dgl
.
graph
([],
card
=
5
)
g
=
dgl
.
graph
([],
num_nodes
=
5
)
g
.
set_n_initializer
(
_init2
,
'h'
)
g
.
set_n_initializer
(
_init2
,
'h'
)
g
.
ndata
[
'h'
]
=
old_repr
g
.
ndata
[
'h'
]
=
old_repr
g
.
update_all
(
_message
,
_reduce
,
_apply
)
g
.
update_all
(
_message
,
_reduce
,
_apply
)
...
@@ -592,7 +609,7 @@ def _test_dynamic_addition():
...
@@ -592,7 +609,7 @@ def _test_dynamic_addition():
def
test_repr
():
def
test_repr
():
G
=
dgl
.
graph
([(
0
,
1
),
(
0
,
2
),
(
1
,
2
)],
card
=
10
)
G
=
dgl
.
graph
([(
0
,
1
),
(
0
,
2
),
(
1
,
2
)],
num_nodes
=
10
)
repr_string
=
G
.
__repr__
()
repr_string
=
G
.
__repr__
()
print
(
repr_string
)
print
(
repr_string
)
G
.
ndata
[
'x'
]
=
F
.
zeros
((
10
,
5
))
G
.
ndata
[
'x'
]
=
F
.
zeros
((
10
,
5
))
...
@@ -773,6 +790,7 @@ def test_issue_1088():
...
@@ -773,6 +790,7 @@ def test_issue_1088():
g
.
update_all
(
fn
.
copy_u
(
'x'
,
'm'
),
fn
.
sum
(
'm'
,
'y'
))
g
.
update_all
(
fn
.
copy_u
(
'x'
,
'm'
),
fn
.
sum
(
'm'
,
'y'
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_isolated_nodes
()
test_nx_conversion
()
test_nx_conversion
()
test_batch_setter_getter
()
test_batch_setter_getter
()
test_batch_setter_autograd
()
test_batch_setter_autograd
()
...
...
tests/compute/test_heterograph.py
View file @
0f40c6e4
...
@@ -118,7 +118,7 @@ def test_create():
...
@@ -118,7 +118,7 @@ def test_create():
try
:
try
:
g
=
dgl
.
graph
(
g
=
dgl
.
graph
(
([
0
,
0
,
0
,
1
,
1
,
2
],
[
0
,
1
,
2
,
0
,
1
,
2
]),
([
0
,
0
,
0
,
1
,
1
,
2
],
[
0
,
1
,
2
,
0
,
1
,
2
]),
card
=
2
,
num_nodes
=
2
,
validate
=
True
validate
=
True
)
)
except
DGLError
:
except
DGLError
:
...
@@ -131,7 +131,7 @@ def test_create():
...
@@ -131,7 +131,7 @@ def test_create():
try
:
try
:
g
=
dgl
.
bipartite
(
g
=
dgl
.
bipartite
(
([
0
,
0
,
1
,
1
,
2
],
[
1
,
1
,
2
,
2
,
3
]),
([
0
,
0
,
1
,
1
,
2
],
[
1
,
1
,
2
,
2
,
3
]),
card
=
card
,
num_nodes
=
card
,
validate
=
True
validate
=
True
)
)
except
DGLError
:
except
DGLError
:
...
@@ -720,14 +720,14 @@ def test_to_device():
...
@@ -720,14 +720,14 @@ def test_to_device():
def
test_convert_bound
():
def
test_convert_bound
():
def
_test_bipartite_bound
(
data
,
card
):
def
_test_bipartite_bound
(
data
,
card
):
try
:
try
:
dgl
.
bipartite
(
data
,
card
=
card
)
dgl
.
bipartite
(
data
,
num_nodes
=
card
)
except
dgl
.
DGLError
:
except
dgl
.
DGLError
:
return
return
assert
False
,
'bipartite bound test with wrong uid failed'
assert
False
,
'bipartite bound test with wrong uid failed'
def
_test_graph_bound
(
data
,
card
):
def
_test_graph_bound
(
data
,
card
):
try
:
try
:
dgl
.
graph
(
data
,
card
=
card
)
dgl
.
graph
(
data
,
num_nodes
=
card
)
except
dgl
.
DGLError
:
except
dgl
.
DGLError
:
return
return
assert
False
,
'graph bound test with wrong uid failed'
assert
False
,
'graph bound test with wrong uid failed'
...
@@ -827,7 +827,7 @@ def test_convert():
...
@@ -827,7 +827,7 @@ def test_convert():
assert
len
(
hg
.
etypes
)
==
2
assert
len
(
hg
.
etypes
)
==
2
# hetero_to_homo test case 2
# hetero_to_homo test case 2
hg
=
dgl
.
bipartite
([(
0
,
0
),
(
1
,
1
)],
card
=
(
2
,
3
))
hg
=
dgl
.
bipartite
([(
0
,
0
),
(
1
,
1
)],
num_nodes
=
(
2
,
3
))
g
=
dgl
.
to_homo
(
hg
)
g
=
dgl
.
to_homo
(
hg
)
assert
g
.
number_of_nodes
()
==
5
assert
g
.
number_of_nodes
()
==
5
...
...
tests/compute/test_sampling.py
View file @
0f40c6e4
...
@@ -152,24 +152,24 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
...
@@ -152,24 +152,24 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
if
reverse
:
if
reverse
:
g
=
dgl
.
graph
([(
0
,
1
),(
0
,
2
),(
0
,
3
),(
1
,
0
),(
1
,
2
),(
1
,
3
),(
2
,
0
)],
g
=
dgl
.
graph
([(
0
,
1
),(
0
,
2
),(
0
,
3
),(
1
,
0
),(
1
,
2
),(
1
,
3
),(
2
,
0
)],
'user'
,
'follow'
,
card
=
card
)
'user'
,
'follow'
,
num_nodes
=
card
)
g
.
edata
[
'prob'
]
=
F
.
tensor
([.
5
,
.
5
,
0.
,
.
5
,
.
5
,
0.
,
1.
],
dtype
=
F
.
float32
)
g
.
edata
[
'prob'
]
=
F
.
tensor
([.
5
,
.
5
,
0.
,
.
5
,
.
5
,
0.
,
1.
],
dtype
=
F
.
float32
)
g1
=
dgl
.
bipartite
([(
0
,
0
),(
1
,
0
),(
2
,
1
),(
2
,
3
)],
'game'
,
'play'
,
'user'
,
card
=
card2
)
g1
=
dgl
.
bipartite
([(
0
,
0
),(
1
,
0
),(
2
,
1
),(
2
,
3
)],
'game'
,
'play'
,
'user'
,
num_nodes
=
card2
)
g1
.
edata
[
'prob'
]
=
F
.
tensor
([.
8
,
.
5
,
.
5
,
.
5
],
dtype
=
F
.
float32
)
g1
.
edata
[
'prob'
]
=
F
.
tensor
([.
8
,
.
5
,
.
5
,
.
5
],
dtype
=
F
.
float32
)
g2
=
dgl
.
bipartite
([(
0
,
2
),(
1
,
2
),(
2
,
2
),(
0
,
1
),(
3
,
1
),(
0
,
0
)],
'user'
,
'liked-by'
,
'game'
,
card
=
card2
)
g2
=
dgl
.
bipartite
([(
0
,
2
),(
1
,
2
),(
2
,
2
),(
0
,
1
),(
3
,
1
),(
0
,
0
)],
'user'
,
'liked-by'
,
'game'
,
num_nodes
=
card2
)
g2
.
edata
[
'prob'
]
=
F
.
tensor
([.
3
,
.
5
,
.
2
,
.
5
,
.
1
,
.
1
],
dtype
=
F
.
float32
)
g2
.
edata
[
'prob'
]
=
F
.
tensor
([.
3
,
.
5
,
.
2
,
.
5
,
.
1
,
.
1
],
dtype
=
F
.
float32
)
g3
=
dgl
.
bipartite
([(
0
,
0
),(
0
,
1
),(
0
,
2
),(
0
,
3
)],
'coin'
,
'flips'
,
'user'
,
card
=
card2
)
g3
=
dgl
.
bipartite
([(
0
,
0
),(
0
,
1
),(
0
,
2
),(
0
,
3
)],
'coin'
,
'flips'
,
'user'
,
num_nodes
=
card2
)
hg
=
dgl
.
hetero_from_relations
([
g
,
g1
,
g2
,
g3
])
hg
=
dgl
.
hetero_from_relations
([
g
,
g1
,
g2
,
g3
])
else
:
else
:
g
=
dgl
.
graph
([(
1
,
0
),(
2
,
0
),(
3
,
0
),(
0
,
1
),(
2
,
1
),(
3
,
1
),(
0
,
2
)],
g
=
dgl
.
graph
([(
1
,
0
),(
2
,
0
),(
3
,
0
),(
0
,
1
),(
2
,
1
),(
3
,
1
),(
0
,
2
)],
'user'
,
'follow'
,
card
=
card
)
'user'
,
'follow'
,
num_nodes
=
card
)
g
.
edata
[
'prob'
]
=
F
.
tensor
([.
5
,
.
5
,
0.
,
.
5
,
.
5
,
0.
,
1.
],
dtype
=
F
.
float32
)
g
.
edata
[
'prob'
]
=
F
.
tensor
([.
5
,
.
5
,
0.
,
.
5
,
.
5
,
0.
,
1.
],
dtype
=
F
.
float32
)
g1
=
dgl
.
bipartite
([(
0
,
0
),(
0
,
1
),(
1
,
2
),(
3
,
2
)],
'user'
,
'play'
,
'game'
,
card
=
card2
)
g1
=
dgl
.
bipartite
([(
0
,
0
),(
0
,
1
),(
1
,
2
),(
3
,
2
)],
'user'
,
'play'
,
'game'
,
num_nodes
=
card2
)
g1
.
edata
[
'prob'
]
=
F
.
tensor
([.
8
,
.
5
,
.
5
,
.
5
],
dtype
=
F
.
float32
)
g1
.
edata
[
'prob'
]
=
F
.
tensor
([.
8
,
.
5
,
.
5
,
.
5
],
dtype
=
F
.
float32
)
g2
=
dgl
.
bipartite
([(
2
,
0
),(
2
,
1
),(
2
,
2
),(
1
,
0
),(
1
,
3
),(
0
,
0
)],
'game'
,
'liked-by'
,
'user'
,
card
=
card2
)
g2
=
dgl
.
bipartite
([(
2
,
0
),(
2
,
1
),(
2
,
2
),(
1
,
0
),(
1
,
3
),(
0
,
0
)],
'game'
,
'liked-by'
,
'user'
,
num_nodes
=
card2
)
g2
.
edata
[
'prob'
]
=
F
.
tensor
([.
3
,
.
5
,
.
2
,
.
5
,
.
1
,
.
1
],
dtype
=
F
.
float32
)
g2
.
edata
[
'prob'
]
=
F
.
tensor
([.
3
,
.
5
,
.
2
,
.
5
,
.
1
,
.
1
],
dtype
=
F
.
float32
)
g3
=
dgl
.
bipartite
([(
0
,
0
),(
1
,
0
),(
2
,
0
),(
3
,
0
)],
'user'
,
'flips'
,
'coin'
,
card
=
card2
)
g3
=
dgl
.
bipartite
([(
0
,
0
),(
1
,
0
),(
2
,
0
),(
3
,
0
)],
'user'
,
'flips'
,
'coin'
,
num_nodes
=
card2
)
hg
=
dgl
.
hetero_from_relations
([
g
,
g1
,
g2
,
g3
])
hg
=
dgl
.
hetero_from_relations
([
g
,
g1
,
g2
,
g3
])
return
g
,
hg
return
g
,
hg
...
...
tests/compute/test_transform.py
View file @
0f40c6e4
...
@@ -326,8 +326,8 @@ def test_compact():
...
@@ -326,8 +326,8 @@ def test_compact():
(
'user'
,
'likes'
,
'user'
):
[(
1
,
8
),
(
8
,
9
)]},
(
'user'
,
'likes'
,
'user'
):
[(
1
,
8
),
(
8
,
9
)]},
{
'user'
:
20
,
'game'
:
10
})
{
'user'
:
20
,
'game'
:
10
})
g3
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
card
=
10
,
ntype
=
'user'
)
g3
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
num_nodes
=
10
,
ntype
=
'user'
)
g4
=
dgl
.
graph
([(
1
,
3
),
(
3
,
5
)],
card
=
10
,
ntype
=
'user'
)
g4
=
dgl
.
graph
([(
1
,
3
),
(
3
,
5
)],
num_nodes
=
10
,
ntype
=
'user'
)
def
_check
(
g
,
new_g
,
induced_nodes
):
def
_check
(
g
,
new_g
,
induced_nodes
):
assert
g
.
ntypes
==
new_g
.
ntypes
assert
g
.
ntypes
==
new_g
.
ntypes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment