Unverified Commit 0f40c6e4 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Hetero] Replace card with num_nodes


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 1b9bc16b
...@@ -246,9 +246,9 @@ class MovieLens(object): ...@@ -246,9 +246,9 @@ class MovieLens(object):
rrow = rating_row[ridx] rrow = rating_row[ridx]
rcol = rating_col[ridx] rcol = rating_col[ridx]
bg = dgl.bipartite((rrow, rcol), 'user', str(rating), 'movie', bg = dgl.bipartite((rrow, rcol), 'user', str(rating), 'movie',
card=(self._num_user, self._num_movie)) num_nodes=(self._num_user, self._num_movie))
rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % str(rating), 'user', rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % str(rating), 'user',
card=(self._num_movie, self._num_user)) num_nodes=(self._num_movie, self._num_user))
rating_graphs.append(bg) rating_graphs.append(bg)
rating_graphs.append(rev_bg) rating_graphs.append(rev_bg)
graph = dgl.hetero_from_relations(rating_graphs) graph = dgl.hetero_from_relations(rating_graphs)
......
...@@ -246,9 +246,9 @@ class MovieLens(object): ...@@ -246,9 +246,9 @@ class MovieLens(object):
rrow = rating_row[ridx] rrow = rating_row[ridx]
rcol = rating_col[ridx] rcol = rating_col[ridx]
bg = dgl.bipartite((rrow, rcol), 'user', str(rating), 'movie', bg = dgl.bipartite((rrow, rcol), 'user', str(rating), 'movie',
card=(self._num_user, self._num_movie)) num_nodes=(self._num_user, self._num_movie))
rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % str(rating), 'user', rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % str(rating), 'user',
card=(self._num_movie, self._num_user)) num_nodes=(self._num_movie, self._num_user))
rating_graphs.append(bg) rating_graphs.append(bg)
rating_graphs.append(rev_bg) rating_graphs.append(rev_bg)
graph = dgl.hetero_from_relations(rating_graphs) graph = dgl.hetero_from_relations(rating_graphs)
......
...@@ -9,7 +9,7 @@ from . import heterograph_index ...@@ -9,7 +9,7 @@ from . import heterograph_index
from .heterograph import DGLHeteroGraph, combine_frames from .heterograph import DGLHeteroGraph, combine_frames
from . import graph_index from . import graph_index
from . import utils from . import utils
from .base import NTYPE, ETYPE, NID, EID, DGLError from .base import NTYPE, ETYPE, NID, EID, DGLError, dgl_warning
__all__ = [ __all__ = [
'graph', 'graph',
...@@ -21,8 +21,8 @@ __all__ = [ ...@@ -21,8 +21,8 @@ __all__ = [
'to_networkx', 'to_networkx',
] ]
def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_format='any', def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True,
**kwargs): restrict_format='any', **kwargs):
"""Create a graph with one type of nodes and edges. """Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...@@ -43,9 +43,12 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma ...@@ -43,9 +43,12 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
Node type name. (Default: _N) Node type name. (Default: _N)
etype : str, optional etype : str, optional
Edge type name. (Default: _E) Edge type name. (Default: _E)
card : int, optional num_nodes : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e. Number of nodes in the graph. If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None) the largest node ID plus 1. (Default: None)
card : int, optional
Deprecated (see :attr:`num_nodes`). Cardinality (number of nodes in the graph).
If None, infer from input data, i.e. the largest node ID plus 1. (Default: None)
validate : bool, optional validate : bool, optional
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. (Default: True) some time. (Default: True)
...@@ -109,18 +112,22 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma ...@@ -109,18 +112,22 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
>>> g.canonical_etypes >>> g.canonical_etypes
[('user', 'follows', 'user')] [('user', 'follows', 'user')]
Check if node ids are within cardinality Check if node ids are within num_nodes specified
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True) >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=2, validate=True)
... ...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2). dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True) >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3, validate=True)
Graph(num_nodes=3, num_edges=3, Graph(num_nodes=3, num_edges=3,
ndata_schemes={} ndata_schemes={}
edata_schemes={}) edata_schemes={})
""" """
if card is not None: if card is not None:
urange, vrange = card, card dgl_warning("card will be deprecated, please use num_nodes='{}' instead.")
num_nodes = card
if num_nodes is not None:
urange, vrange = num_nodes, num_nodes
else: else:
urange, vrange = None, None urange, vrange = None, None
if isinstance(data, tuple): if isinstance(data, tuple):
...@@ -141,8 +148,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma ...@@ -141,8 +148,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=True, restrict_forma
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True, def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None,
restrict_format='any', **kwargs): validate=True, restrict_format='any', **kwargs):
"""Create a bipartite graph. """Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes The result graph is directed and edges must be from ``utype`` nodes
...@@ -168,9 +175,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True ...@@ -168,9 +175,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
Edge type name. (Default: _E) Edge type name. (Default: _E)
vtype : str, optional vtype : str, optional
Destination node type name. (Default: _V) Destination node type name. (Default: _V)
card : pair of int, optional num_nodes : 2-tuple of int, optional
Cardinality (number of nodes in the source and destination group). If None, Number of nodes in the source and destination group. If None, infer from input data,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None) i.e. the largest node ID plus 1 for each type. (Default: None)
card : 2-tuple of int, optional
Deprecated (see :attr:`num_nodes`). Cardinality (number of nodes in the source and
destination group). If None, infer from input data, i.e. the largest node ID plus 1
for each type. (Default: None)
validate : bool, optional validate : bool, optional
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. (Default: True) some time. (Default: True)
...@@ -246,12 +257,12 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True ...@@ -246,12 +257,12 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
>>> g.edges() >>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])) (tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality Check if node ids are within num_nodes specified
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True) >>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), num_nodes=(2, 4), validate=True)
... ...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2). dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True) >>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), num_nodes=(3, 4), validate=True)
>>> g >>> g
Graph(num_nodes={'_U': 3, '_V': 4}, Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3}, num_edges={('_U', '_E', '_V'): 3},
...@@ -260,7 +271,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True ...@@ -260,7 +271,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True
if utype == vtype: if utype == vtype:
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.') raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
if card is not None: if card is not None:
urange, vrange = card dgl_warning("card will be deprecated, please use num_nodes='{}' instead.")
num_nodes = card
if num_nodes is not None:
urange, vrange = num_nodes
else: else:
urange, vrange = None, None urange, vrange = None, None
if isinstance(data, tuple): if isinstance(data, tuple):
...@@ -321,9 +335,9 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None): ...@@ -321,9 +335,9 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
the relation graphs. the relation graphs.
>>> # A graph with 4 nodes of type 'user' >>> # A graph with 4 nodes of type 'user'
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', card=4) >>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', num_nodes=4)
>>> # A bipartite graph with 4 nodes of src type ('user') and 2 nodes of dst type ('game') >>> # A bipartite graph with 4 nodes of src type ('user') and 2 nodes of dst type ('game')
>>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game', card=(4, 2)) >>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game', num_nodes=(4, 2))
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game') >>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
>>> print(g) >>> print(g)
...@@ -468,11 +482,11 @@ def heterograph(data_dict, num_nodes_dict=None): ...@@ -468,11 +482,11 @@ def heterograph(data_dict, num_nodes_dict=None):
elif srctype == dsttype: elif srctype == dsttype:
rel_graphs.append(graph( rel_graphs.append(graph(
data, srctype, etype, data, srctype, etype,
card=num_nodes_dict[srctype], validate=False)) num_nodes=num_nodes_dict[srctype], validate=False))
else: else:
rel_graphs.append(bipartite( rel_graphs.append(bipartite(
data, srctype, etype, dsttype, data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False)) num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
return hetero_from_relations(rel_graphs, num_nodes_dict) return hetero_from_relations(rel_graphs, num_nodes_dict)
...@@ -625,11 +639,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -625,11 +639,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
if stid == dtid: if stid == dtid:
rel_graph = graph( rel_graph = graph(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
card=ntype_count[stid], validate=False) num_nodes=ntype_count[stid], validate=False)
else: else:
rel_graph = bipartite( rel_graph = bipartite(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid], (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
card=(ntype_count[stid], ntype_count[dtid]), validate=False) num_nodes=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph) rel_graphs.append(rel_graph)
hg = hetero_from_relations( hg = hetero_from_relations(
...@@ -717,7 +731,7 @@ def to_homo(G): ...@@ -717,7 +731,7 @@ def to_homo(G):
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu())) etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges)) eids.append(F.arange(0, num_edges))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes, validate=False) retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes, validate=False)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0) retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0) retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0) retg.edata[ETYPE] = F.cat(etype_ids, 0)
......
...@@ -31,6 +31,6 @@ def rand_graph(num_nodes, num_edges, restrict_format='any'): ...@@ -31,6 +31,6 @@ def rand_graph(num_nodes, num_edges, restrict_format='any'):
rows = F.astype(eids / num_nodes, F.dtype(eids)) rows = F.astype(eids / num_nodes, F.dtype(eids))
cols = F.astype(eids % num_nodes, F.dtype(eids)) cols = F.astype(eids % num_nodes, F.dtype(eids))
g = convert.graph((rows, cols), g = convert.graph((rows, cols),
card=num_nodes, validate=False, num_nodes=num_nodes, validate=False,
restrict_format=restrict_format) restrict_format=restrict_format)
return g return g
...@@ -3831,7 +3831,7 @@ class DGLHeteroGraph(object): ...@@ -3831,7 +3831,7 @@ class DGLHeteroGraph(object):
>>> import torch >>> import torch
>>> import dgl >>> import dgl
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g = dgl.graph([], 'user', 'follows', card=4) >>> g = dgl.graph([], 'user', 'follows', num_nodes=4)
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user') >>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
tensor([1, 2]) tensor([1, 2])
......
...@@ -109,7 +109,7 @@ class RandomWalkNeighborSampler(object): ...@@ -109,7 +109,7 @@ class RandomWalkNeighborSampler(object):
# count the number of visits and pick the K-most frequent neighbors for each node # count the number of visits and pick the K-most frequent neighbors for each node
neighbor_graph = convert.graph( neighbor_graph = convert.graph(
(src, dst), card=self.G.number_of_nodes(self.ntype), ntype=self.ntype) (src, dst), num_nodes=self.G.number_of_nodes(self.ntype), ntype=self.ntype)
neighbor_graph = transform.to_simple(neighbor_graph, return_counts=self.weight_column) neighbor_graph = transform.to_simple(neighbor_graph, return_counts=self.weight_column)
counts = neighbor_graph.edata[self.weight_column] counts = neighbor_graph.edata[self.weight_column]
neighbor_graph = select_topk(neighbor_graph, self.num_neighbors, self.weight_column) neighbor_graph = select_topk(neighbor_graph, self.num_neighbors, self.weight_column)
......
...@@ -654,7 +654,7 @@ def compact_graphs(graphs, always_preserve=None): ...@@ -654,7 +654,7 @@ def compact_graphs(graphs, always_preserve=None):
The following code constructs a bipartite graph with 20 users and 10 games, but The following code constructs a bipartite graph with 20 users and 10 games, but
only user #1 and #3, as well as game #3 and #5, have connections: only user #1 and #3, as well as game #3 and #5, have connections:
>>> g = dgl.bipartite([(1, 3), (3, 5)], 'user', 'plays', 'game', card=(20, 10)) >>> g = dgl.bipartite([(1, 3), (3, 5)], 'user', 'plays', 'game', num_nodes=(20, 10))
The following would compact the graph above to another bipartite graph with only The following would compact the graph above to another bipartite graph with only
two users and two games. two users and two games.
...@@ -676,7 +676,7 @@ def compact_graphs(graphs, always_preserve=None): ...@@ -676,7 +676,7 @@ def compact_graphs(graphs, always_preserve=None):
of the given graphs are removed. So if we compact ``g`` and the following ``g2`` of the given graphs are removed. So if we compact ``g`` and the following ``g2``
graphs together: graphs together:
>>> g2 = dgl.bipartite([(1, 6), (6, 8)], 'user', 'plays', 'game', card=(20, 10)) >>> g2 = dgl.bipartite([(1, 6), (6, 8)], 'user', 'plays', 'game', num_nodes=(20, 10))
>>> (new_g, new_g2), induced_nodes = dgl.compact_graphs([g, g2]) >>> (new_g, new_g2), induced_nodes = dgl.compact_graphs([g, g2])
>>> induced_nodes >>> induced_nodes
{'user': tensor([1, 3, 6]), 'game': tensor([3, 5, 6, 8])} {'user': tensor([1, 3, 6]), 'game': tensor([3, 5, 6, 8])}
......
...@@ -60,6 +60,23 @@ def generate_graph(grad=False): ...@@ -60,6 +60,23 @@ def generate_graph(grad=False):
g.set_e_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer)
return g return g
def test_isolated_nodes():
g = dgl.graph([(0, 1), (1, 2)], num_nodes=5)
assert g.number_of_nodes() == 5
# Test backward compatibility
g = dgl.graph([(0, 1), (1, 2)], card=5)
assert g.number_of_nodes() == 5
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', num_nodes=(5, 7))
assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7
# Test backward compatibility
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', card=(5, 7))
assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7
def test_batch_setter_getter(): def test_batch_setter_getter():
def _pfc(x): def _pfc(x):
return list(F.zerocopy_to_numpy(x)[:,0]) return list(F.zerocopy_to_numpy(x)[:,0])
...@@ -452,8 +469,8 @@ def test_update_all_0deg(): ...@@ -452,8 +469,8 @@ def test_update_all_0deg():
assert F.allclose(new_repr[1:], 2*(2+F.zeros((4,5)))) assert F.allclose(new_repr[1:], 2*(2+F.zeros((4,5))))
assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0)) assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0))
# test#2: graph with no edge # test#2:
g = dgl.graph([], card=5) g = dgl.graph([], num_nodes=5)
g.set_n_initializer(_init2, 'h') g.set_n_initializer(_init2, 'h')
g.ndata['h'] = old_repr g.ndata['h'] = old_repr
g.update_all(_message, _reduce, _apply) g.update_all(_message, _reduce, _apply)
...@@ -592,7 +609,7 @@ def _test_dynamic_addition(): ...@@ -592,7 +609,7 @@ def _test_dynamic_addition():
def test_repr(): def test_repr():
G = dgl.graph([(0,1), (0,2), (1,2)], card=10) G = dgl.graph([(0,1), (0,2), (1,2)], num_nodes=10)
repr_string = G.__repr__() repr_string = G.__repr__()
print(repr_string) print(repr_string)
G.ndata['x'] = F.zeros((10, 5)) G.ndata['x'] = F.zeros((10, 5))
...@@ -773,6 +790,7 @@ def test_issue_1088(): ...@@ -773,6 +790,7 @@ def test_issue_1088():
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y')) g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
if __name__ == '__main__': if __name__ == '__main__':
test_isolated_nodes()
test_nx_conversion() test_nx_conversion()
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd() test_batch_setter_autograd()
......
...@@ -118,7 +118,7 @@ def test_create(): ...@@ -118,7 +118,7 @@ def test_create():
try: try:
g = dgl.graph( g = dgl.graph(
([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]), ([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),
card=2, num_nodes=2,
validate=True validate=True
) )
except DGLError: except DGLError:
...@@ -131,7 +131,7 @@ def test_create(): ...@@ -131,7 +131,7 @@ def test_create():
try: try:
g = dgl.bipartite( g = dgl.bipartite(
([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]), ([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]),
card=card, num_nodes=card,
validate=True validate=True
) )
except DGLError: except DGLError:
...@@ -720,14 +720,14 @@ def test_to_device(): ...@@ -720,14 +720,14 @@ def test_to_device():
def test_convert_bound(): def test_convert_bound():
def _test_bipartite_bound(data, card): def _test_bipartite_bound(data, card):
try: try:
dgl.bipartite(data, card=card) dgl.bipartite(data, num_nodes=card)
except dgl.DGLError: except dgl.DGLError:
return return
assert False, 'bipartite bound test with wrong uid failed' assert False, 'bipartite bound test with wrong uid failed'
def _test_graph_bound(data, card): def _test_graph_bound(data, card):
try: try:
dgl.graph(data, card=card) dgl.graph(data, num_nodes=card)
except dgl.DGLError: except dgl.DGLError:
return return
assert False, 'graph bound test with wrong uid failed' assert False, 'graph bound test with wrong uid failed'
...@@ -827,7 +827,7 @@ def test_convert(): ...@@ -827,7 +827,7 @@ def test_convert():
assert len(hg.etypes) == 2 assert len(hg.etypes) == 2
# hetero_to_homo test case 2 # hetero_to_homo test case 2
hg = dgl.bipartite([(0, 0), (1, 1)], card=(2, 3)) hg = dgl.bipartite([(0, 0), (1, 1)], num_nodes=(2, 3))
g = dgl.to_homo(hg) g = dgl.to_homo(hg)
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
......
...@@ -152,24 +152,24 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse): ...@@ -152,24 +152,24 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
if reverse: if reverse:
g = dgl.graph([(0,1),(0,2),(0,3),(1,0),(1,2),(1,3),(2,0)], g = dgl.graph([(0,1),(0,2),(0,3),(1,0),(1,2),(1,3),(2,0)],
'user', 'follow', card=card) 'user', 'follow', num_nodes=card)
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32) g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
g1 = dgl.bipartite([(0,0),(1,0),(2,1),(2,3)], 'game', 'play', 'user', card=card2) g1 = dgl.bipartite([(0,0),(1,0),(2,1),(2,3)], 'game', 'play', 'user', num_nodes=card2)
g1.edata['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32) g1.edata['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32)
g2 = dgl.bipartite([(0,2),(1,2),(2,2),(0,1),(3,1),(0,0)], 'user', 'liked-by', 'game', card=card2) g2 = dgl.bipartite([(0,2),(1,2),(2,2),(0,1),(3,1),(0,0)], 'user', 'liked-by', 'game', num_nodes=card2)
g2.edata['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32) g2.edata['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
g3 = dgl.bipartite([(0,0),(0,1),(0,2),(0,3)], 'coin', 'flips', 'user', card=card2) g3 = dgl.bipartite([(0,0),(0,1),(0,2),(0,3)], 'coin', 'flips', 'user', num_nodes=card2)
hg = dgl.hetero_from_relations([g, g1, g2, g3]) hg = dgl.hetero_from_relations([g, g1, g2, g3])
else: else:
g = dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)], g = dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)],
'user', 'follow', card=card) 'user', 'follow', num_nodes=card)
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32) g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
g1 = dgl.bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game', card=card2) g1 = dgl.bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game', num_nodes=card2)
g1.edata['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32) g1.edata['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32)
g2 = dgl.bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user', card=card2) g2 = dgl.bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user', num_nodes=card2)
g2.edata['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32) g2.edata['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
g3 = dgl.bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin', card=card2) g3 = dgl.bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin', num_nodes=card2)
hg = dgl.hetero_from_relations([g, g1, g2, g3]) hg = dgl.hetero_from_relations([g, g1, g2, g3])
return g, hg return g, hg
......
...@@ -326,8 +326,8 @@ def test_compact(): ...@@ -326,8 +326,8 @@ def test_compact():
('user', 'likes', 'user'): [(1, 8), (8, 9)]}, ('user', 'likes', 'user'): [(1, 8), (8, 9)]},
{'user': 20, 'game': 10}) {'user': 20, 'game': 10})
g3 = dgl.graph([(0, 1), (1, 2)], card=10, ntype='user') g3 = dgl.graph([(0, 1), (1, 2)], num_nodes=10, ntype='user')
g4 = dgl.graph([(1, 3), (3, 5)], card=10, ntype='user') g4 = dgl.graph([(1, 3), (3, 5)], num_nodes=10, ntype='user')
def _check(g, new_g, induced_nodes): def _check(g, new_g, induced_nodes):
assert g.ntypes == new_g.ntypes assert g.ntypes == new_g.ntypes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment