Unverified Commit dc8ca88e authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Refactor] Explicit dtype for HeteroGraph (#1467)



* 111

* 111

* lint

* lint

* lint

* lint

* fix

* lint

* try

* fix

* lint

* lint

* test

* fix

* ttt

* test

* fix

* fix

* fix

* mxnet

* 111

* fix 64bits computation

* pylint

* roll back

* fix

* lint

* fix hetero_from_relations

* remove index_dtype in to_homo and to_hetero

* fix

* fix

* fix

* fix

* remove default

* fix

* lint

* fix

* fix error message

* fix error

* lint

* macro dispatch

* try

* lint

* remove nbits

* error message

* fix

* fix

* lint

* lint

* lint

* fix

* lint

* fix

* fix random walk

* lint

* lint

* fix

* fix

* fix

* lint

* fix

* lint
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent de34e15a
...@@ -110,7 +110,7 @@ def schedule_recv(graph, ...@@ -110,7 +110,7 @@ def schedule_recv(graph,
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
# sort and unique the argument # sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes, graph.gidx.dtype)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# reduce # reduce
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid), reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
...@@ -161,7 +161,7 @@ def schedule_snr(graph, ...@@ -161,7 +161,7 @@ def schedule_snr(graph,
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes, graph.gidx.dtype)
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
...@@ -216,13 +216,13 @@ def schedule_update_all(graph, ...@@ -216,13 +216,13 @@ def schedule_update_all(graph,
if graph.num_edges() == 0: if graph.num_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes # All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None: if apply_func is not None:
nodes = utils.toindex(slice(0, graph.num_dst())) nodes = utils.toindex(slice(0, graph.num_dst()), graph.gidx.dtype)
schedule_apply_nodes(nodes, apply_func, graph.dstframe, schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe, inplace=False, outframe=outframe,
ntype=graph.canonical_etype[-1]) ntype=graph.canonical_etype[-1])
else: else:
eid = utils.toindex(slice(0, graph.num_edges())) # ALL eid = utils.toindex(slice(0, graph.num_edges()), graph.gidx.dtype) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL recv_nodes = utils.toindex(slice(0, graph.num_dst()), graph.gidx.dtype) # ALL
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
...@@ -484,8 +484,9 @@ def schedule_pull(graph, ...@@ -484,8 +484,9 @@ def schedule_pull(graph,
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace, schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace,
outframe, ntype=graph.canonical_etype[-1]) outframe, ntype=graph.canonical_etype[-1])
else: else:
# TODO(Allen): Change operation to dgl operation
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor())) pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes) pull_nodes = utils.toindex(pull_nodes, graph.gidx.dtype)
# create vars # create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf') var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf') var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
...@@ -953,7 +954,7 @@ def _gen_send_reduce( ...@@ -953,7 +954,7 @@ def _gen_send_reduce(
return var_out return var_out
else: else:
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) mid = utils.toindex(slice(0, len(var_v.data)), var_v.data.dtype)
db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data, db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf, reduce_nodes, var_dst_nf, var_mf,
var_out, ntype=canonical_etype[-1]) var_out, ntype=canonical_etype[-1])
......
...@@ -463,6 +463,7 @@ def metapath_reachable_graph(g, metapath): ...@@ -463,6 +463,7 @@ def metapath_reachable_graph(g, metapath):
A homogeneous or bipartite graph. A homogeneous or bipartite graph.
""" """
adj = 1 adj = 1
index_dtype = g._idtype_str
for etype in metapath: for etype in metapath:
adj = adj * g.adj(etype=etype, scipy_fmt='csr', transpose=True) adj = adj * g.adj(etype=etype, scipy_fmt='csr', transpose=True)
...@@ -471,9 +472,9 @@ def metapath_reachable_graph(g, metapath): ...@@ -471,9 +472,9 @@ def metapath_reachable_graph(g, metapath):
dsttype = g.to_canonical_etype(metapath[-1])[2] dsttype = g.to_canonical_etype(metapath[-1])[2]
if srctype == dsttype: if srctype == dsttype:
assert adj.shape[0] == adj.shape[1] assert adj.shape[0] == adj.shape[1]
new_g = graph(adj, ntype=srctype) new_g = graph(adj, ntype=srctype, index_dtype=index_dtype)
else: else:
new_g = bipartite(adj, utype=srctype, vtype=dsttype) new_g = bipartite(adj, utype=srctype, vtype=dsttype, index_dtype=index_dtype)
for key, value in g.nodes[srctype].data.items(): for key, value in g.nodes[srctype].data.items():
new_g.nodes[srctype].data[key] = value new_g.nodes[srctype].data[key] = value
...@@ -744,14 +745,16 @@ def compact_graphs(graphs, always_preserve=None): ...@@ -744,14 +745,16 @@ def compact_graphs(graphs, always_preserve=None):
# Ensure the node types are ordered the same. # Ensure the node types are ordered the same.
# TODO(BarclayII): we ideally need to remove this constraint. # TODO(BarclayII): we ideally need to remove this constraint.
ntypes = graphs[0].ntypes ntypes = graphs[0].ntypes
graph_dtype = graphs[0]._graph.dtype() graph_dtype = graphs[0]._idtype_str
graph_ctx = graphs[0]._graph.ctx() graph_ctx = graphs[0]._graph.ctx()
for g in graphs: for g in graphs:
assert ntypes == g.ntypes, \ assert ntypes == g.ntypes, \
("All graphs should have the same node types in the same order, got %s and %s" % ("All graphs should have the same node types in the same order, got %s and %s" %
ntypes, g.ntypes) ntypes, g.ntypes)
assert graph_dtype == g._graph.dtype(), "Graph data type mismatch" assert graph_dtype == g._idtype_str, "Expect graph data type to be {}, but got {}".format(
assert graph_ctx == g._graph.ctx(), "Graph device mismatch" graph_dtype, g._idtype_str)
assert graph_ctx == g._graph.ctx(), "Expect graph device to be {}, but got {}".format(
graph_ctx, g._graph.ctx())
# Process the dictionary or tensor of "always preserve" nodes # Process the dictionary or tensor of "always preserve" nodes
if always_preserve is None: if always_preserve is None:
...@@ -919,7 +922,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -919,7 +922,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
if nodes is not None: if nodes is not None:
dst_nodes_nd.append(F.zerocopy_to_dgl_ndarray(nodes)) dst_nodes_nd.append(F.zerocopy_to_dgl_ndarray(nodes))
else: else:
dst_nodes_nd.append(nd.NULL) dst_nodes_nd.append(nd.NULL[g._idtype_str])
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock( new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src) g._graph, dst_nodes_nd, include_dst_in_src)
...@@ -935,7 +938,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True): ...@@ -935,7 +938,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype] new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype]
else: else:
# For empty dst node sets, still create empty mapping arrays. # For empty dst node sets, still create empty mapping arrays.
new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=F.int64) new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=g.idtype)
for i, canonical_etype in enumerate(g.canonical_etypes): for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data) induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data)
...@@ -970,8 +973,12 @@ def remove_edges(g, edge_ids): ...@@ -970,8 +973,12 @@ def remove_edges(g, edge_ids):
"Graph has more than one edge type; specify a dict for edge_id instead.") "Graph has more than one edge type; specify a dict for edge_id instead.")
edge_ids = {g.canonical_etypes[0]: edge_ids} edge_ids = {g.canonical_etypes[0]: edge_ids}
edge_ids_nd = [nd.NULL] * len(g.etypes) edge_ids_nd = [nd.NULL[g._idtype_str]] * len(g.etypes)
for key, value in edge_ids.items(): for key, value in edge_ids.items():
if value.dtype != g.idtype:
# if didn't check, this function still works, but returns wrong result
raise utils.InconsistentDtypeException("Expect edge id tensors({}) to have \
the same index type as graph({})".format(value.dtype, g.idtype))
edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value) edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value)
new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd) new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd)
...@@ -1018,9 +1025,9 @@ def in_subgraph(g, nodes): ...@@ -1018,9 +1025,9 @@ def in_subgraph(g, nodes):
nodes_all_types = [] nodes_all_types = []
for ntype in g.ntypes: for ntype in g.ntypes:
if ntype in nodes: if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor()) nodes_all_types.append(utils.toindex(nodes[ntype], g._idtype_str).todgltensor())
else: else:
nodes_all_types.append(nd.array([], ctx=nd.cpu())) nodes_all_types.append(nd.NULL[g._idtype_str])
subgidx = _CAPI_DGLInSubgraph(g._graph, nodes_all_types) subgidx = _CAPI_DGLInSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
...@@ -1057,9 +1064,9 @@ def out_subgraph(g, nodes): ...@@ -1057,9 +1064,9 @@ def out_subgraph(g, nodes):
nodes_all_types = [] nodes_all_types = []
for ntype in g.ntypes: for ntype in g.ntypes:
if ntype in nodes: if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor()) nodes_all_types.append(utils.toindex(nodes[ntype], g._idtype_str).todgltensor())
else: else:
nodes_all_types.append(nd.array([], ctx=nd.cpu())) nodes_all_types.append(nd.NULL[g._idtype_str])
subgidx = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types) subgidx = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
...@@ -1135,7 +1142,7 @@ def to_simple(g, return_counts='count', writeback_mapping=None): ...@@ -1135,7 +1142,7 @@ def to_simple(g, return_counts='count', writeback_mapping=None):
def as_heterograph(g, ntype='_U', etype='_E'): def as_heterograph(g, ntype='_U', etype='_E'):
"""Convert a DGLGraph to a DGLHeteroGraph with one node and edge type. """Convert a DGLGraph to a DGLHeteroGraph with one node and edge type.
Node and edge features are preserved. Node and edge features are preserved. Returns 64 bits graph
Parameters Parameters
---------- ----------
......
...@@ -9,9 +9,19 @@ from .base import DGLError ...@@ -9,9 +9,19 @@ from .base import DGLError
from . import backend as F from . import backend as F
from . import ndarray as nd from . import ndarray as nd
class InconsistentDtypeException(DGLError):
"""Exception class for inconsistent dtype between graph and tensor"""
def __init__(self, msg='', *args, **kwargs): #pylint: disable=W1113
prefix_message = 'DGL now requires the input tensor to have\
the same dtype as the graph index\'s dtype(which you can get by g.idype). '
super().__init__(prefix_message + msg, *args, **kwargs)
class Index(object): class Index(object):
"""Index class that can be easily converted to list/tensor.""" """Index class that can be easily converted to list/tensor."""
def __init__(self, data): def __init__(self, data, dtype="int64"):
assert dtype in ['int32', 'int64']
self.dtype = dtype
self._initialize_data(data) self._initialize_data(data)
def _initialize_data(self, data): def _initialize_data(self, data):
...@@ -43,18 +53,22 @@ class Index(object): ...@@ -43,18 +53,22 @@ class Index(object):
def _dispatch(self, data): def _dispatch(self, data):
"""Store data based on its type.""" """Store data based on its type."""
if F.is_tensor(data): if F.is_tensor(data):
if F.dtype(data) != F.int64: if F.dtype(data) != F.data_type_dict[self.dtype]:
raise DGLError('Index data must be an int64 vector, but got: %s' % str(data)) raise InconsistentDtypeException('Index data specified as %s, but got: %s' %
(self.dtype,
F.reverse_data_type_dict[F.dtype(data)]))
if len(F.shape(data)) > 1: if len(F.shape(data)) > 1:
raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data)) raise InconsistentDtypeException('Index data must be 1D int32/int64 vector,\
but got shape: %s' % str(F.shape(data)))
if len(F.shape(data)) == 0: if len(F.shape(data)) == 0:
# a tensor of one int # a tensor of one int
self._dispatch(int(data)) self._dispatch(int(data))
else: else:
self._user_tensor_data[F.context(data)] = data self._user_tensor_data[F.context(data)] = data
elif isinstance(data, nd.NDArray): elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1): if not (data.dtype == self.dtype and len(data.shape) == 1):
raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data)) raise InconsistentDtypeException('Index data must be 1D %s vector, but got: %s' %
(self.dtype, data.dtype))
self._dgl_tensor_data = data self._dgl_tensor_data = data
elif isinstance(data, slice): elif isinstance(data, slice):
# save it in the _pydata temporarily; materialize it if `tonumpy` is called # save it in the _pydata temporarily; materialize it if `tonumpy` is called
...@@ -63,7 +77,7 @@ class Index(object): ...@@ -63,7 +77,7 @@ class Index(object):
self._slice_data = slice(data.start, data.stop) self._slice_data = slice(data.start, data.stop)
else: else:
try: try:
data = np.asarray(data, dtype=np.int64) data = np.asarray(data, dtype=self.dtype)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
raise DGLError('Error index data: %s' % str(data)) raise DGLError('Error index data: %s' % str(data))
if data.ndim == 0: # scalar array if data.ndim == 0: # scalar array
...@@ -79,7 +93,7 @@ class Index(object): ...@@ -79,7 +93,7 @@ class Index(object):
if self._pydata is None: if self._pydata is None:
if self._slice_data is not None: if self._slice_data is not None:
slc = self._slice_data slc = self._slice_data
self._pydata = np.arange(slc.start, slc.stop).astype(np.int64) self._pydata = np.arange(slc.start, slc.stop).astype(self.dtype)
elif self._dgl_tensor_data is not None: elif self._dgl_tensor_data is not None:
self._pydata = self._dgl_tensor_data.asnumpy() self._pydata = self._dgl_tensor_data.asnumpy()
else: else:
...@@ -128,12 +142,13 @@ class Index(object): ...@@ -128,12 +142,13 @@ class Index(object):
def __getstate__(self): def __getstate__(self):
if self._slice_data is not None: if self._slice_data is not None:
# the index can be represented by a slice # the index can be represented by a slice
return self._slice_data return self._slice_data, self.dtype
else: else:
return self.tousertensor() return self.tousertensor(), self.dtype
def __setstate__(self, state): def __setstate__(self, state):
self._initialize_data(state) data, self.dtype = state
self._initialize_data(data)
def get_items(self, index): def get_items(self, index):
"""Return values at given positions of an Index """Return values at given positions of an Index
...@@ -155,18 +170,22 @@ class Index(object): ...@@ -155,18 +170,22 @@ class Index(object):
# the provided index is not a slice # the provided index is not a slice
tensor = self.tousertensor() tensor = self.tousertensor()
index = index.tousertensor() index = index.tousertensor()
return Index(F.gather_row(tensor, index)) # TODO(Allen): Change F.gather_row to dgl operation
return Index(F.gather_row(tensor, index), self.dtype)
elif self._slice_data is None: elif self._slice_data is None:
# the current index is not a slice but the provided is a slice # the current index is not a slice but the provided is a slice
tensor = self.tousertensor() tensor = self.tousertensor()
index = index._slice_data index = index._slice_data
return Index(F.narrow_row(tensor, index.start, index.stop)) # TODO(Allen): Change F.narrow_row to dgl operation
return Index(F.astype(F.narrow_row(tensor, index.start, index.stop),
F.data_type_dict[self.dtype]),
self.dtype)
else: else:
# both self and index wrap a slice object, then return another # both self and index wrap a slice object, then return another
# Index wrapping a slice # Index wrapping a slice
start = self._slice_data.start start = self._slice_data.start
index = index._slice_data index = index._slice_data
return Index(slice(start + index.start, start + index.stop)) return Index(slice(start + index.start, start + index.stop), self.dtype)
def set_items(self, index, value): def set_items(self, index, value):
"""Set values at given positions of an Index. Set is not done in place, """Set values at given positions of an Index. Set is not done in place,
...@@ -191,7 +210,7 @@ class Index(object): ...@@ -191,7 +210,7 @@ class Index(object):
value = F.full_1d(len(index), value, dtype=F.int64, ctx=F.cpu()) value = F.full_1d(len(index), value, dtype=F.int64, ctx=F.cpu())
else: else:
value = value.tousertensor() value = value.tousertensor()
return Index(F.scatter_row(tensor, index, value)) return Index(F.scatter_row(tensor, index, value), self.dtype)
def append_zeros(self, num): def append_zeros(self, num):
"""Append zeros to an Index """Append zeros to an Index
...@@ -205,24 +224,24 @@ class Index(object): ...@@ -205,24 +224,24 @@ class Index(object):
return self return self
new_items = F.zeros((num,), dtype=F.int64, ctx=F.cpu()) new_items = F.zeros((num,), dtype=F.int64, ctx=F.cpu())
if len(self) == 0: if len(self) == 0:
return Index(new_items) return Index(new_items, self.dtype)
else: else:
tensor = self.tousertensor() tensor = self.tousertensor()
tensor = F.cat((tensor, new_items), dim=0) tensor = F.cat((tensor, new_items), dim=0)
return Index(tensor) return Index(tensor, self.dtype)
def nonzero(self): def nonzero(self):
"""Return the nonzero positions""" """Return the nonzero positions"""
tensor = self.tousertensor() tensor = self.tousertensor()
mask = F.nonzero_1d(tensor != 0) mask = F.nonzero_1d(tensor != 0)
return Index(mask) return Index(mask, self.dtype)
def has_nonzero(self): def has_nonzero(self):
"""Check if there is any nonzero value in this Index""" """Check if there is any nonzero value in this Index"""
tensor = self.tousertensor() tensor = self.tousertensor()
return F.sum(tensor, 0) > 0 return F.sum(tensor, 0) > 0
def toindex(data): def toindex(data, dtype='int64'):
"""Convert the given data to Index object. """Convert the given data to Index object.
Parameters Parameters
...@@ -239,16 +258,17 @@ def toindex(data): ...@@ -239,16 +258,17 @@ def toindex(data):
-------- --------
Index Index
""" """
return data if isinstance(data, Index) else Index(data) return data if isinstance(data, Index) else Index(data, dtype)
def zero_index(size): def zero_index(size, dtype="int64"):
"""Create a index with provided size initialized to zero """Create a index with provided size initialized to zero
Parameters Parameters
---------- ----------
size: int size: int
""" """
return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu())) return Index(F.zeros((size,), dtype=F.data_type_dict[dtype], ctx=F.cpu()),
dtype=dtype)
def set_diff(ar1, ar2): def set_diff(ar1, ar2):
"""Find the set difference of two index arrays. """Find the set difference of two index arrays.
......
...@@ -277,7 +277,8 @@ class HeteroNodeView(object): ...@@ -277,7 +277,8 @@ class HeteroNodeView(object):
def __call__(self, ntype=None): def __call__(self, ntype=None):
"""Return the nodes.""" """Return the nodes."""
return F.arange(0, self._graph.number_of_nodes(ntype)) return F.arange(0, self._graph.number_of_nodes(ntype),
dtype=self._graph._idtype_str)
class HeteroNodeDataView(MutableMapping): class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called.""" """The data view class when G.ndata[ntype] is called."""
......
...@@ -115,7 +115,7 @@ template <DLDeviceType XPU, typename IdType> ...@@ -115,7 +115,7 @@ template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2) { IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]); CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0]; const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L); IdArray ret = NewIdArray(2 * L, DLContext{kDLCPU, 0}, arr1->dtype.bits);
const IdType* arr1_data = static_cast<IdType*>(arr1->data); const IdType* arr1_data = static_cast<IdType*>(arr1->data);
const IdType* arr2_data = static_cast<IdType*>(arr2->data); const IdType* arr2_data = static_cast<IdType*>(arr2->data);
IdType* ret_data = static_cast<IdType*>(ret->data); IdType* ret_data = static_cast<IdType*>(ret->data);
...@@ -173,7 +173,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -173,7 +173,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
} }
} }
// map array // map array
IdArray maparr = NewIdArray(newid); IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data); IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) { for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first; maparr_data[kv.second] = kv.first;
......
...@@ -10,6 +10,13 @@ ...@@ -10,6 +10,13 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "../../c_api_common.h"
#define CHECK_SAME_DTYPE(VAR1, VAR2) \
CHECK(VAR1->dtype == VAR2->dtype) \
<< "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) << "(" \
<< (VAR1)->dtype << ")" \
<< ". But got " << (VAR2)->dtype;
namespace dgl { namespace dgl {
......
...@@ -22,6 +22,7 @@ void CSRRemoveConsecutive( ...@@ -22,6 +22,7 @@ void CSRRemoveConsecutive(
std::vector<IdType> *new_indptr, std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices, std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) { std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
...@@ -54,6 +55,7 @@ void CSRRemoveShuffled( ...@@ -54,6 +55,7 @@ void CSRRemoveShuffled(
std::vector<IdType> *new_indptr, std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices, std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) { std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
const IdType *eid_data = static_cast<IdType *>(csr.data->data); const IdType *eid_data = static_cast<IdType *>(csr.data->data);
...@@ -77,6 +79,7 @@ void CSRRemoveShuffled( ...@@ -77,6 +79,7 @@ void CSRRemoveShuffled(
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
if (n_entries == 0) if (n_entries == 0)
......
...@@ -43,6 +43,8 @@ template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t); ...@@ -43,6 +43,8 @@ template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) { NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
CHECK_SAME_DTYPE(csr.indices, row);
CHECK_SAME_DTYPE(csr.indices, col);
const auto rowlen = row->shape[0]; const auto rowlen = row->shape[0];
const auto collen = col->shape[0]; const auto collen = col->shape[0];
const auto rstlen = std::max(rowlen, collen); const auto rstlen = std::max(rowlen, collen);
...@@ -98,6 +100,7 @@ template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t); ...@@ -98,6 +100,7 @@ template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) { NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const auto len = rows->shape[0]; const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data); const IdType* vid_data = static_cast<IdType*>(rows->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -194,6 +197,8 @@ template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t); ...@@ -194,6 +197,8 @@ template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -261,6 +266,8 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data ...@@ -261,6 +266,8 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) { std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
// TODO(minjie): more efficient implementation for matrix without duplicate entries // TODO(minjie): more efficient implementation for matrix without duplicate entries
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -448,6 +455,7 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t); ...@@ -448,6 +455,7 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
...@@ -494,6 +502,8 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray); ...@@ -494,6 +502,8 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) { CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
IdHashMap<IdType> hashmap(cols); IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0]; const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0]; const int64_t new_ncols = cols->shape[0];
......
...@@ -106,6 +106,7 @@ template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t); ...@@ -106,6 +106,7 @@ template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) { NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
CHECK_SAME_DTYPE(coo.col, rows);
const auto len = rows->shape[0]; const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data); const IdType* vid_data = static_cast<IdType*>(rows->data);
NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx); NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
...@@ -171,8 +172,10 @@ template NDArray COOGetData<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t); ...@@ -171,8 +172,10 @@ template NDArray COOGetData<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
///////////////////////////// COOGetDataAndIndices ///////////////////////////// ///////////////////////////// COOGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices( std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
COOMatrix coo, NDArray rows, NDArray cols) { NDArray cols) {
CHECK_SAME_DTYPE(coo.col, rows);
CHECK_SAME_DTYPE(coo.col, cols);
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
const int64_t len = std::max(rowlen, collen); const int64_t len = std::max(rowlen, collen);
......
...@@ -235,12 +235,34 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph( ...@@ -235,12 +235,34 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
} }
} }
FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const { HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs;
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::AsNumBits(g, bits));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
}
/*!
 * \brief Flatten the given edge types into one homogeneous graph,
 *        dispatching on the graph's index dtype.
 *
 * \param etypes Edge type ids to flatten.
 * \return Pointer to the flattened graph.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  const int64_t bits = NumBits();
  if (bits == 32) {
    return FlattenImpl<int32_t>(etypes);
  } else if (bits == 64) {
    return FlattenImpl<int64_t>(etypes);
  }
  // The original fell off the end here for any other width, which is
  // undefined behavior for a value-returning function.
  LOG(FATAL) << "Unsupported number of index bits: " << bits;
  return nullptr;  // unreachable; silences missing-return warnings
}
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets; std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0; size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_id_t> result_src, result_dst; std::vector<IdType> result_src, result_dst;
std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype; std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid; std::vector<IdType> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set; std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this // XXXtype_offsets contain the mapping from node type and number of nodes after this
...@@ -261,7 +283,6 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp ...@@ -261,7 +283,6 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
dsttype_set.push_back(dsttype); dsttype_set.push_back(dsttype);
} }
} }
// Sort the node types so that we can compare the sets and decide whether a homograph // Sort the node types so that we can compare the sets and decide whether a homograph
// should be returned. // should be returned.
std::sort(srctype_set.begin(), srctype_set.end()); std::sort(srctype_set.begin(), srctype_set.end());
...@@ -301,9 +322,9 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp ...@@ -301,9 +322,9 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
EdgeArray edges = Edges(etype); EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype); size_t num_edges = NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data); const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data); const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data); const IdType* edges_eid_data = static_cast<const IdType*>(edges.id->data);
// TODO(gq) Use concat? // TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) { for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset); result_src.push_back(edges_src_data[i] + srctype_offset);
......
...@@ -202,6 +202,9 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -202,6 +202,9 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \return Save HeteroGraph to stream, using CSRMatrix */ /*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
...@@ -214,6 +217,15 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -214,6 +217,15 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief A map from vert type to the number of verts in the type */ /*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_; std::vector<int64_t> num_verts_per_type_;
  /*! \brief Template helper function for the Flatten operation.
   *
   * \tparam IdType Graph's index data type, can be int32_t or int64_t
   * \param etypes vector of etypes to be flattened
   * \return pointer of FlattenedHeteroGraph
   */
template <class IdType>
FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const;
}; };
} // namespace dgl } // namespace dgl
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
* \file graph/heterograph_capi.cc * \file graph/heterograph_capi.cc
* \brief Heterograph CAPI bindings. * \brief Heterograph CAPI bindings.
*/ */
#include "./heterograph.h" #include <dgl/array.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./heterograph.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -409,7 +411,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") ...@@ -409,7 +411,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int bits = args[1]; int bits = args[1];
HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits); HeteroGraphPtr bhg_ptr = hg.sptr();
auto hg_ptr = std::dynamic_pointer_cast<HeteroGraph>(bhg_ptr);
HeteroGraphPtr hg_new;
if (hg_ptr) {
hg_new = HeteroGraph::AsNumBits(hg_ptr, bits);
} else {
hg_new = UnitGraph::AsNumBits(bhg_ptr, bits);
}
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
...@@ -429,13 +438,22 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion") ...@@ -429,13 +438,22 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1]; List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 0)
<< "Expect graph list has at least one graph";
std::vector<HeteroGraphPtr> component_ptrs; std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size()); component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits();
for (const auto& component : component_graphs) { for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr()); component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to batch have the same index dtype(int" << bits
<< "), but got int" << component->NumBits();
} }
auto hgptr = DisjointUnionHeteroGraph(meta_graph.sptr(), component_ptrs); ATEN_ID_BITS_SWITCH(bits, IdType, {
auto hgptr =
DisjointUnionHeteroGraph<IdType>(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
});
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
...@@ -443,8 +461,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") ...@@ -443,8 +461,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1]; const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2]; const IdArray edge_sizes = args[2];
const auto& ret = DisjointPartitionHeteroBySizes( const int64_t bits = hg->NumBits();
hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes); std::vector<HeteroGraphPtr> ret;
ATEN_ID_BITS_SWITCH(bits, IdType, {
ret = DisjointPartitionHeteroBySizes<IdType>(hg->meta_graph(), hg.sptr(),
vertex_sizes, edge_sizes);
});
List<HeteroGraphRef> ret_list; List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) { for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr)); ret_list.push_back(HeteroGraphRef(hgptr));
......
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../unit_graph.h"
#include "randomwalk.h" #include "randomwalk.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -22,13 +24,22 @@ namespace { ...@@ -22,13 +24,22 @@ namespace {
/*! /*!
* \brief Random walk based on the given metapath. * \brief Random walk based on the given metapath.
* *
* \tparam IdType Index dtype of graph
* \param hg The heterograph * \param hg The heterograph
* \param etypes The metapath as an array of edge type IDs * \param etypes The metapath as an array of edge type IDs
* \param seeds The array of starting vertices for random walks * \param seeds The array of starting vertices for random walks
* \param num_traces Number of traces to generate for each starting vertex * \param num_traces Number of traces to generate for each starting vertex
* \note The metapath should have the same starting and ending node type. * \note The metapath should have the same starting and ending node type.
*/ */
template <typename T>
RandomWalkTracesPtr MetapathRandomWalk( RandomWalkTracesPtr MetapathRandomWalk(
const HeteroGraphPtr hg,
const IdArray etypes,
const IdArray seeds,
int num_traces);
template <>
RandomWalkTracesPtr MetapathRandomWalk<int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray etypes, const IdArray etypes,
const IdArray seeds, const IdArray seeds,
...@@ -74,10 +85,64 @@ RandomWalkTracesPtr MetapathRandomWalk( ...@@ -74,10 +85,64 @@ RandomWalkTracesPtr MetapathRandomWalk(
return RandomWalkTracesPtr(tl); return RandomWalkTracesPtr(tl);
} }
/*!
 * \brief Metapath random walk specialization for int32 HeteroGraphs.
 *
 * Generates \a num_traces traces from every seed vertex, following the edge
 * types in \a etypes in order via the 32-bit successor iterator (SuccVec32).
 * A trace stops early when the current vertex has no successor of the
 * requested edge type.
 * TODO: Refactor this with CSR and COO operations
 */
template <>
RandomWalkTracesPtr MetapathRandomWalk<int32_t>(
    const HeteroGraphPtr hg,
    const IdArray etypes,
    const IdArray seeds,
    int num_traces) {
  uint64_t num_etypes = etypes->shape[0];
  uint64_t num_seeds = seeds->shape[0];
  const dgl_type_t *etype_data = static_cast<dgl_type_t *>(etypes->data);
  const int32_t *seed_data = static_cast<int32_t *>(seeds->data);

  std::vector<int32_t> vertices;       // concatenated vertices of all traces
  std::vector<size_t> trace_lengths;   // length of each individual trace
  std::vector<size_t> trace_counts;    // number of traces per seed

  // TODO(quan): use omp to parallelize this loop
  for (uint64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
    int curr_num_traces = 0;
    for (; curr_num_traces < num_traces; ++curr_num_traces) {
      int32_t curr = seed_data[seed_id];
      size_t trace_length = 0;
      for (size_t i = 0; i < num_etypes; ++i) {
        auto ug = std::dynamic_pointer_cast<UnitGraph>(
            hg->GetRelationGraph(etype_data[i]));
        CHECK_NOTNULL(ug);
        const auto &succ = ug->SuccVec32(etype_data[i], curr);
        if (succ.size() == 0)
          break;  // dead end: terminate this trace early
        curr = succ[RandomEngine::ThreadLocal()->RandInt(succ.size())];
        vertices.push_back(curr);
        ++trace_length;
      }
      trace_lengths.push_back(trace_length);
    }
    trace_counts.push_back(curr_num_traces);
  }

  // NOTE(review): removed the unused local `metagraph` the original computed
  // from hg->meta_graph() and never read.
  RandomWalkTraces *tl = new RandomWalkTraces;
  tl->vertices = VecToIdArray(vertices);
  tl->trace_lengths = VecToIdArray(trace_lengths);
  tl->trace_counts = VecToIdArray(trace_counts);

  return RandomWalkTracesPtr(tl);
}
}; // namespace }; // namespace
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk") DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef hg = args[0]; const HeteroGraphRef hg = args[0];
const IdArray etypes = args[1]; const IdArray etypes = args[1];
const IdArray seeds = args[2]; const IdArray seeds = args[2];
...@@ -89,7 +154,11 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk") ...@@ -89,7 +154,11 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
CHECK(aten::IsValidIdArray(seeds)); CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU) CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "MetapathRandomWalk only support CPU sampling"; << "MetapathRandomWalk only support CPU sampling";
const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces); const int64_t bits = hg->NumBits();
RandomWalkTracesPtr tl;
ATEN_ID_BITS_SWITCH(bits, IdType, {
tl = MetapathRandomWalk<IdType>(hg.sptr(), etypes, seeds, num_traces);
});
*rv = RandomWalkTracesRef(tl); *rv = RandomWalkTracesRef(tl);
}); });
......
...@@ -8,6 +8,7 @@ using namespace dgl::runtime; ...@@ -8,6 +8,7 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph( HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty"; CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
...@@ -19,16 +20,16 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -19,16 +20,16 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
auto pair = meta_graph->FindEdge(etype); auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
dgl_id_t src_offset = 0, dst_offset = 0; IdType src_offset = 0, dst_offset = 0;
std::vector<dgl_id_t> result_src, result_dst; std::vector<IdType> result_src, result_dst;
// Loop over all graphs // Loop over all graphs
for (size_t i = 0; i < component_graphs.size(); ++i) { for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i]; const auto& cg = component_graphs[i];
EdgeArray edges = cg->Edges(etype); EdgeArray edges = cg->Edges(etype);
size_t num_edges = cg->NumEdges(etype); size_t num_edges = cg->NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data); const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data); const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
// Loop over all edges // Loop over all edges
for (size_t j = 0; j < num_edges; ++j) { for (size_t j = 0; j < num_edges; ++j) {
...@@ -41,11 +42,9 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -41,11 +42,9 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
dst_offset += cg->NumVertices(dst_vtype); dst_offset += cg->NumVertices(dst_vtype);
} }
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO( HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2, (src_vtype == dst_vtype) ? 1 : 2, src_offset, dst_offset,
src_offset, aten::VecToIdArray(result_src, sizeof(IdType) * 8),
dst_offset, aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset; num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset; num_nodes_per_type[dst_vtype] = dst_offset;
...@@ -53,6 +52,13 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -53,6 +52,13 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type)); return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
} }
template HeteroGraphPtr DisjointUnionHeteroGraph<int32_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template HeteroGraphPtr DisjointUnionHeteroGraph<int64_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) { GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
// Sanity check for vertex sizes // Sanity check for vertex sizes
...@@ -102,11 +108,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -102,11 +108,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
EdgeArray edges = batched_graph->Edges(etype); EdgeArray edges = batched_graph->Edges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data); const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data); const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
// Loop over all graphs to be unbatched // Loop over all graphs to be unbatched
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
std::vector<dgl_id_t> result_src, result_dst; std::vector<IdType> result_src, result_dst;
// Loop over the chunk of edges for the specified graph and edge type // Loop over the chunk of edges for the specified graph and edge type
for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) { for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) {
// TODO(mufei): Should use array operations to implement this. // TODO(mufei): Should use array operations to implement this.
...@@ -114,11 +120,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -114,11 +120,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]); result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
} }
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO( HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2, (src_vtype == dst_vtype) ? 1 : 2,
vertex_sizes_data[src_vtype * batch_size + g], vertex_sizes_data[src_vtype * batch_size + g],
vertex_sizes_data[dst_vtype * batch_size + g], vertex_sizes_data[dst_vtype * batch_size + g],
aten::VecToIdArray(result_src), aten::VecToIdArray(result_src, sizeof(IdType) * 8),
aten::VecToIdArray(result_dst)); aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
rel_graphs[g].push_back(rgptr); rel_graphs[g].push_back(rgptr);
} }
} }
...@@ -133,4 +139,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -133,4 +139,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
return rst; return rst;
} }
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
} // namespace dgl } // namespace dgl
...@@ -645,6 +645,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -645,6 +645,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override { DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context // TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later. // of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data); const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data); const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid]; const dgl_id_t start = indptr_data[vid];
...@@ -652,9 +653,20 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -652,9 +653,20 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return DGLIdIters(indices_data + start, indices_data + end); return DGLIdIters(indices_data + start, indices_data + end);
} }
// 32-bit counterpart of SuccVec: iterate successors of `vid` under `etype`.
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
  // Guard against reinterpreting 64-bit indices as 32-bit — consistent with
  // the CHECK_EQ(NumBits(), 64) guards in the 64-bit accessors.
  CHECK_EQ(NumBits(), 32);
  const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
  const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
  const int32_t start = indptr_data[vid];
  const int32_t end = indptr_data[vid + 1];
  return DGLIdIters32(indices_data + start, indices_data + end);
}
DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context // TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later. // of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data); const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data); const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
const dgl_id_t start = indptr_data[vid]; const dgl_id_t start = indptr_data[vid];
...@@ -951,6 +963,13 @@ DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const { ...@@ -951,6 +963,13 @@ DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
return ptr->SuccVec(etype, vid); return ptr->SuccVec(etype, vid);
} }
// Delegate to the CSR representation's 32-bit successor iterator.
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat preferred = SelectFormat(SparseFormat::kCSR);
  const auto csr = std::dynamic_pointer_cast<CSR>(GetFormat(preferred));
  CHECK_NOTNULL(csr);
  return csr->SuccVec32(etype, vid);
}
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::kCSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
......
...@@ -139,6 +139,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -139,6 +139,9 @@ class UnitGraph : public BaseHeteroGraph {
DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override; DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;
  // 32-bit variant of SuccVec (temporary patch for int32 graphs)
DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;
DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override; DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;
DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override; DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;
......
...@@ -10,28 +10,29 @@ ...@@ -10,28 +10,29 @@
namespace dgl { namespace dgl {
namespace sched { namespace sched {
template <class IdType>
std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids, std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
const IdArray& recv_ids) { const IdArray& recv_ids) {
auto n_msgs = msg_ids->shape[0]; auto n_msgs = msg_ids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const IdType* vid_data = static_cast<IdType*>(vids->data);
const int64_t* msg_id_data = static_cast<int64_t*>(msg_ids->data); const IdType* msg_id_data = static_cast<IdType*>(msg_ids->data);
const int64_t* recv_id_data = static_cast<int64_t*>(recv_ids->data); const IdType* recv_id_data = static_cast<IdType*>(recv_ids->data);
// in edge: dst->msgs // in edge: dst->msgs
std::unordered_map<int64_t, std::vector<int64_t>> in_edges; std::unordered_map<IdType, std::vector<IdType>> in_edges;
for (int64_t i = 0; i < n_msgs; ++i) { for (IdType i = 0; i < n_msgs; ++i) {
in_edges[vid_data[i]].push_back(msg_id_data[i]); in_edges[vid_data[i]].push_back(msg_id_data[i]);
} }
// bkt: deg->dsts // bkt: deg->dsts
std::unordered_map<int64_t, std::vector<int64_t>> bkt; std::unordered_map<IdType, std::vector<IdType>> bkt;
for (const auto& it : in_edges) { for (const auto& it : in_edges) {
bkt[it.second.size()].push_back(it.first); bkt[it.second.size()].push_back(it.first);
} }
std::unordered_set<int64_t> zero_deg_nodes; std::unordered_set<IdType> zero_deg_nodes;
for (int64_t i = 0; i < recv_ids->shape[0]; ++i) { for (IdType i = 0; i < recv_ids->shape[0]; ++i) {
if (in_edges.find(recv_id_data[i]) == in_edges.end()) { if (in_edges.find(recv_id_data[i]) == in_edges.end()) {
zero_deg_nodes.insert(recv_id_data[i]); zero_deg_nodes.insert(recv_id_data[i]);
} }
...@@ -39,9 +40,9 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids ...@@ -39,9 +40,9 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
auto n_zero_deg = zero_deg_nodes.size(); auto n_zero_deg = zero_deg_nodes.size();
// calc output size // calc output size
int64_t n_deg = bkt.size(); IdType n_deg = bkt.size();
int64_t n_dst = in_edges.size(); IdType n_dst = in_edges.size();
int64_t n_mid_sec = bkt.size(); // zero deg won't affect message size IdType n_mid_sec = bkt.size(); // zero deg won't affect message size
if (n_zero_deg > 0) { if (n_zero_deg > 0) {
n_deg += 1; n_deg += 1;
n_dst += n_zero_deg; n_dst += n_zero_deg;
...@@ -53,16 +54,16 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids ...@@ -53,16 +54,16 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
IdArray nid_section = IdArray::Empty({n_deg}, vids->dtype, vids->ctx); IdArray nid_section = IdArray::Empty({n_deg}, vids->dtype, vids->ctx);
IdArray mids = IdArray::Empty({n_msgs}, vids->dtype, vids->ctx); IdArray mids = IdArray::Empty({n_msgs}, vids->dtype, vids->ctx);
IdArray mid_section = IdArray::Empty({n_mid_sec}, vids->dtype, vids->ctx); IdArray mid_section = IdArray::Empty({n_mid_sec}, vids->dtype, vids->ctx);
int64_t* deg_ptr = static_cast<int64_t*>(degs->data); IdType* deg_ptr = static_cast<IdType*>(degs->data);
int64_t* nid_ptr = static_cast<int64_t*>(nids->data); IdType* nid_ptr = static_cast<IdType*>(nids->data);
int64_t* nsec_ptr = static_cast<int64_t*>(nid_section->data); IdType* nsec_ptr = static_cast<IdType*>(nid_section->data);
int64_t* mid_ptr = static_cast<int64_t*>(mids->data); IdType* mid_ptr = static_cast<IdType*>(mids->data);
int64_t* msec_ptr = static_cast<int64_t*>(mid_section->data); IdType* msec_ptr = static_cast<IdType*>(mid_section->data);
// fill in bucketing ordering // fill in bucketing ordering
for (const auto& it : bkt) { // for each bucket for (const auto& it : bkt) { // for each bucket
const int64_t deg = it.first; const IdType deg = it.first;
const int64_t bucket_size = it.second.size(); const IdType bucket_size = it.second.size();
*deg_ptr++ = deg; *deg_ptr++ = deg;
*nsec_ptr++ = bucket_size; *nsec_ptr++ = bucket_size;
*msec_ptr++ = deg * bucket_size; *msec_ptr++ = deg * bucket_size;
...@@ -92,28 +93,37 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids ...@@ -92,28 +93,37 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
return std::move(ret); return std::move(ret);
} }
std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& vids, template std::vector<IdArray> DegreeBucketing<int32_t>(const IdArray& msg_ids,
const IdArray& vids,
const IdArray& recv_ids);
template std::vector<IdArray> DegreeBucketing<int64_t>(const IdArray& msg_ids,
const IdArray& vids,
const IdArray& recv_ids);
template <class IdType>
std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids,
const IdArray& vids,
const IdArray& eids) { const IdArray& eids) {
auto n_edge = eids->shape[0]; auto n_edge = eids->shape[0];
const int64_t* eid_data = static_cast<int64_t*>(eids->data); const IdType* eid_data = static_cast<IdType*>(eids->data);
const int64_t* uid_data = static_cast<int64_t*>(uids->data); const IdType* uid_data = static_cast<IdType*>(uids->data);
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const IdType* vid_data = static_cast<IdType*>(vids->data);
// node2edge: group_by nodes uid -> (eid, the other end vid) // node2edge: group_by nodes uid -> (eid, the other end vid)
std::unordered_map<int64_t, std::unordered_map<IdType, std::vector<std::pair<IdType, IdType>>> node2edge;
std::vector<std::pair<int64_t, int64_t>>> node2edge; for (IdType i = 0; i < n_edge; ++i) {
for (int64_t i = 0; i < n_edge; ++i) {
node2edge[uid_data[i]].emplace_back(eid_data[i], vid_data[i]); node2edge[uid_data[i]].emplace_back(eid_data[i], vid_data[i]);
} }
// bkt: deg -> group_by node uid // bkt: deg -> group_by node uid
std::unordered_map<int64_t, std::vector<int64_t>> bkt; std::unordered_map<IdType, std::vector<IdType>> bkt;
for (const auto& it : node2edge) { for (const auto& it : node2edge) {
bkt[it.second.size()].push_back(it.first); bkt[it.second.size()].push_back(it.first);
} }
// number of unique degree // number of unique degree
int64_t n_deg = bkt.size(); IdType n_deg = bkt.size();
// initialize output // initialize output
IdArray degs = IdArray::Empty({n_deg}, eids->dtype, eids->ctx); IdArray degs = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
...@@ -121,18 +131,18 @@ std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& v ...@@ -121,18 +131,18 @@ std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& v
IdArray new_vids = IdArray::Empty({n_edge}, vids->dtype, vids->ctx); IdArray new_vids = IdArray::Empty({n_edge}, vids->dtype, vids->ctx);
IdArray new_eids = IdArray::Empty({n_edge}, eids->dtype, eids->ctx); IdArray new_eids = IdArray::Empty({n_edge}, eids->dtype, eids->ctx);
IdArray sections = IdArray::Empty({n_deg}, eids->dtype, eids->ctx); IdArray sections = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
int64_t* deg_ptr = static_cast<int64_t*>(degs->data); IdType* deg_ptr = static_cast<IdType*>(degs->data);
int64_t* uid_ptr = static_cast<int64_t*>(new_uids->data); IdType* uid_ptr = static_cast<IdType*>(new_uids->data);
int64_t* vid_ptr = static_cast<int64_t*>(new_vids->data); IdType* vid_ptr = static_cast<IdType*>(new_vids->data);
int64_t* eid_ptr = static_cast<int64_t*>(new_eids->data); IdType* eid_ptr = static_cast<IdType*>(new_eids->data);
int64_t* sec_ptr = static_cast<int64_t*>(sections->data); IdType* sec_ptr = static_cast<IdType*>(sections->data);
// fill in bucketing ordering // fill in bucketing ordering
for (const auto& it : bkt) { // for each bucket for (const auto& it : bkt) { // for each bucket
// degree of this bucket // degree of this bucket
const int64_t deg = it.first; const IdType deg = it.first;
// number of edges in this bucket // number of edges in this bucket
const int64_t bucket_size = it.second.size(); const IdType bucket_size = it.second.size();
*deg_ptr++ = deg; *deg_ptr++ = deg;
*sec_ptr++ = deg * bucket_size; *sec_ptr++ = deg * bucket_size;
for (const auto u : it.second) { // for uid in this bucket for (const auto u : it.second) { // for uid in this bucket
...@@ -154,6 +164,12 @@ std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& v ...@@ -154,6 +164,12 @@ std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& v
return std::move(ret); return std::move(ret);
} }
template std::vector<IdArray> GroupEdgeByNodeDegree<int32_t>(
const IdArray& uids, const IdArray& vids, const IdArray& eids);
template std::vector<IdArray> GroupEdgeByNodeDegree<int64_t>(
const IdArray& uids, const IdArray& vids, const IdArray& eids);
} // namespace sched } // namespace sched
} // namespace dgl } // namespace dgl
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
* \file scheduler/scheduler_apis.cc * \file scheduler/scheduler_apis.cc
* \brief DGL scheduler APIs * \brief DGL scheduler APIs
*/ */
#include <dgl/array.h>
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/scheduler.h> #include <dgl/scheduler.h>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../array/cpu/array_utils.h"
using dgl::runtime::DGLArgs; using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue; using dgl::runtime::DGLRetValue;
...@@ -14,11 +16,16 @@ using dgl::runtime::NDArray; ...@@ -14,11 +16,16 @@ using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing") DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray msg_ids = args[0]; const IdArray msg_ids = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
const IdArray nids = args[2]; const IdArray nids = args[2];
*rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(msg_ids, vids, nids)); CHECK_SAME_DTYPE(msg_ids, vids);
CHECK_SAME_DTYPE(msg_ids, nids);
ATEN_ID_TYPE_SWITCH(msg_ids->dtype, IdType, {
*rv = ConvertNDArrayVectorToPackedFunc(
sched::DegreeBucketing<IdType>(msg_ids, vids, nids));
});
}); });
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree") DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
...@@ -26,8 +33,12 @@ DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree") ...@@ -26,8 +33,12 @@ DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
const IdArray uids = args[0]; const IdArray uids = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
const IdArray eids = args[2]; const IdArray eids = args[2];
CHECK_SAME_DTYPE(uids, vids);
CHECK_SAME_DTYPE(uids, eids);
ATEN_ID_TYPE_SWITCH(uids->dtype, IdType, {
*rv = ConvertNDArrayVectorToPackedFunc( *rv = ConvertNDArrayVectorToPackedFunc(
sched::GroupEdgeByNodeDegree(uids, vids, eids)); sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids));
});
}); });
} // namespace dgl } // namespace dgl
...@@ -2,6 +2,7 @@ import dgl ...@@ -2,6 +2,7 @@ import dgl
import backend as F import backend as F
from dgl.base import ALL from dgl.base import ALL
from utils import parametrize_dtype
def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None): def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes assert g1.ntypes == g2.ntypes
...@@ -32,18 +33,19 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N ...@@ -32,18 +33,19 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
for feat_name in edge_attrs[ety]: for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]) assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
def test_batching_hetero_topology(): @parametrize_dtype
def test_batching_hetero_topology(index_dtype):
"""Test batching two DGLHeteroGraphs where some nodes are isolated in some relations""" """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
g1 = dgl.heterograph({ g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)], ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)] ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
}) }, index_dtype=index_dtype)
g2 = dgl.heterograph({ g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)], ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)] ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
}) }, index_dtype=index_dtype)
bg = dgl.batch_hetero([g1, g2]) bg = dgl.batch_hetero([g1, g2])
assert bg.ntypes == g2.ntypes assert bg.ntypes == g2.ntypes
...@@ -90,21 +92,23 @@ def test_batching_hetero_topology(): ...@@ -90,21 +92,23 @@ def test_batching_hetero_topology():
check_equivalence_between_heterographs(g1, g3) check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4) check_equivalence_between_heterographs(g2, g4)
def test_batching_hetero_and_batched_hetero_topology():
@parametrize_dtype
def test_batching_hetero_and_batched_hetero_topology(index_dtype):
"""Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph.""" """Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
g1 = dgl.heterograph({ g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)] ('user', 'plays', 'game'): [(0, 0), (1, 0)]
}) }, index_dtype=index_dtype)
g2 = dgl.heterograph({ g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)] ('user', 'plays', 'game'): [(0, 0), (1, 0)]
}) }, index_dtype=index_dtype)
bg1 = dgl.batch_hetero([g1, g2]) bg1 = dgl.batch_hetero([g1, g2])
g3 = dgl.heterograph({ g3 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1)], ('user', 'follows', 'user'): [(0, 1)],
('user', 'plays', 'game'): [(1, 0)] ('user', 'plays', 'game'): [(1, 0)]
}) }, index_dtype=index_dtype)
bg2 = dgl.batch_hetero([bg1, g3]) bg2 = dgl.batch_hetero([bg1, g3])
assert bg2.ntypes == g3.ntypes assert bg2.ntypes == g3.ntypes
assert bg2.etypes == g3.etypes assert bg2.etypes == g3.etypes
...@@ -149,12 +153,13 @@ def test_batching_hetero_and_batched_hetero_topology(): ...@@ -149,12 +153,13 @@ def test_batching_hetero_and_batched_hetero_topology():
check_equivalence_between_heterographs(g2, g5) check_equivalence_between_heterographs(g2, g5)
check_equivalence_between_heterographs(g3, g6) check_equivalence_between_heterographs(g3, g6)
def test_batched_features(): @parametrize_dtype
def test_batched_features(index_dtype):
"""Test the features of batched DGLHeteroGraphs""" """Test the features of batched DGLHeteroGraphs"""
g1 = dgl.heterograph({ g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)] ('user', 'plays', 'game'): [(0, 0), (1, 0)]
}) }, index_dtype=index_dtype)
g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g1.nodes['game'].data['h1'] = F.tensor([[0.]]) g1.nodes['game'].data['h1'] = F.tensor([[0.]])
...@@ -166,7 +171,7 @@ def test_batched_features(): ...@@ -166,7 +171,7 @@ def test_batched_features():
g2 = dgl.heterograph({ g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)], ('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)] ('user', 'plays', 'game'): [(0, 0), (1, 0)]
}) }, index_dtype=index_dtype)
g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g2.nodes['game'].data['h1'] = F.tensor([[0.]]) g2.nodes['game'].data['h1'] = F.tensor([[0.]])
......
...@@ -6,6 +6,9 @@ import dgl ...@@ -6,6 +6,9 @@ import dgl
import networkx as nx import networkx as nx
from collections import defaultdict as ddict from collections import defaultdict as ddict
import unittest import unittest
import pytest
import inspect
from utils import parametrize_dtype
D = 5 D = 5
reduce_msg_shapes = set() reduce_msg_shapes = set()
...@@ -25,7 +28,7 @@ def reduce_func(nodes): ...@@ -25,7 +28,7 @@ def reduce_func(nodes):
def apply_node_func(nodes): def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['accum']} return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False): def generate_graph(index_dtype='int64', grad=False):
''' '''
s, d, eid s, d, eid
0, 1, 0 0, 1, 0
...@@ -47,7 +50,7 @@ def generate_graph(grad=False): ...@@ -47,7 +50,7 @@ def generate_graph(grad=False):
9, 0, 16 9, 0, 16
''' '''
g = dgl.graph([(0,1), (1,9), (0,2), (2,9), (0,3), (3,9), (0,4), (4,9), g = dgl.graph([(0,1), (1,9), (0,2), (2,9), (0,3), (3,9), (0,4), (4,9),
(0,5), (5,9), (0,6), (6,9), (0,7), (7,9), (0,8), (8,9), (9,0)]) (0,5), (5,9), (0,6), (6,9), (0,7), (7,9), (0,8), (8,9), (9,0)], index_dtype=index_dtype)
ncol = F.randn((10, D)) ncol = F.randn((10, D))
ecol = F.randn((17, D)) ecol = F.randn((17, D))
if grad: if grad:
...@@ -60,27 +63,35 @@ def generate_graph(grad=False): ...@@ -60,27 +63,35 @@ def generate_graph(grad=False):
g.set_e_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer)
return g return g
def test_isolated_nodes():
g = dgl.graph([(0, 1), (1, 2)], num_nodes=5) @parametrize_dtype
def test_isolated_nodes(index_dtype):
g = dgl.graph([(0, 1), (1, 2)], num_nodes=5, index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
# Test backward compatibility # Test backward compatibility
g = dgl.graph([(0, 1), (1, 2)], card=5) g = dgl.graph([(0, 1), (1, 2)], card=5, index_dtype=index_dtype)
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', num_nodes=(5, 7)) g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays',
'game', num_nodes=(5, 7), index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes('user') == 5 assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7 assert g.number_of_nodes('game') == 7
# Test backward compatibility # Test backward compatibility
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', card=(5, 7)) g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays',
'game', card=(5, 7), index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes('user') == 5 assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7 assert g.number_of_nodes('game') == 7
def test_batch_setter_getter(): @parametrize_dtype
def test_batch_setter_getter(index_dtype):
def _pfc(x): def _pfc(x):
return list(F.zerocopy_to_numpy(x)[:,0]) return list(F.zerocopy_to_numpy(x)[:,0])
g = generate_graph() g = generate_graph(index_dtype)
# set all nodes # set all nodes
g.ndata['h'] = F.zeros((10, D)) g.ndata['h'] = F.zeros((10, D))
assert F.allclose(g.ndata['h'], F.zeros((10, D))) assert F.allclose(g.ndata['h'], F.zeros((10, D)))
...@@ -90,11 +101,11 @@ def test_batch_setter_getter(): ...@@ -90,11 +101,11 @@ def test_batch_setter_getter():
assert len(g.ndata) == old_len - 1 assert len(g.ndata) == old_len - 1
g.ndata['h'] = F.zeros((10, D)) g.ndata['h'] = F.zeros((10, D))
# set partial nodes # set partial nodes
u = F.tensor([1, 3, 5]) u = F.tensor([1, 3, 5], F.data_type_dict[index_dtype])
g.nodes[u].data['h'] = F.ones((3, D)) g.nodes[u].data['h'] = F.ones((3, D))
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.] assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes # get partial nodes
u = F.tensor([1, 2, 3]) u = F.tensor([1, 2, 3], F.data_type_dict[index_dtype])
assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.] assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]
''' '''
...@@ -126,42 +137,44 @@ def test_batch_setter_getter(): ...@@ -126,42 +137,44 @@ def test_batch_setter_getter():
assert len(g.edata) == old_len - 1 assert len(g.edata) == old_len - 1
g.edata['l'] = F.zeros((17, D)) g.edata['l'] = F.zeros((17, D))
# set partial edges (many-many) # set partial edges (many-many)
u = F.tensor([0, 0, 2, 5, 9]) u = F.tensor([0, 0, 2, 5, 9], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 3, 9, 9, 0]) v = F.tensor([1, 3, 9, 9, 0], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((5, D)) g.edges[u, v].data['l'] = F.ones((5, D))
truth = [0.] * 17 truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1. truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.edata['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (many-one) # set partial edges (many-one)
u = F.tensor([3, 4, 6]) u = F.tensor([3, 4, 6], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9]) v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((3, D)) g.edges[u, v].data['l'] = F.ones((3, D))
truth[5] = truth[7] = truth[11] = 1. truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.edata['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (one-many) # set partial edges (one-many)
u = F.tensor([0]) u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([4, 5, 6]) v = F.tensor([4, 5, 6], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((3, D)) g.edges[u, v].data['l'] = F.ones((3, D))
truth[6] = truth[8] = truth[10] = 1. truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.edata['l']) == truth assert _pfc(g.edata['l']) == truth
# get partial edges (many-many) # get partial edges (many-many)
u = F.tensor([0, 6, 0]) u = F.tensor([0, 6, 0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([6, 9, 7]) v = F.tensor([6, 9, 7], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one) # get partial edges (many-one)
u = F.tensor([5, 6, 7]) u = F.tensor([5, 6, 7], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9]) v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many) # get partial edges (one-many)
u = F.tensor([0]) u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([3, 4, 5]) v = F.tensor([3, 4, 5], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True) @parametrize_dtype
def test_batch_setter_autograd(index_dtype):
g = generate_graph(index_dtype=index_dtype, grad=True)
h1 = g.ndata['h'] h1 = g.ndata['h']
# partial set # partial set
v = F.tensor([1, 2, 8]) v = F.tensor([1, 2, 8], F.data_type_dict[index_dtype])
hh = F.attach_grad(F.zeros((len(v), D))) hh = F.attach_grad(F.zeros((len(v), D)))
with F.record_grad(): with F.record_grad():
g.nodes[v].data['h'] = hh g.nodes[v].data['h'] = hh
...@@ -170,7 +183,9 @@ def test_batch_setter_autograd(): ...@@ -170,7 +183,9 @@ def test_batch_setter_autograd():
assert F.array_equal(F.grad(h1)[:,0], F.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])) assert F.array_equal(F.grad(h1)[:,0], F.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
assert F.array_equal(F.grad(hh)[:,0], F.tensor([2., 2., 2.])) assert F.array_equal(F.grad(hh)[:,0], F.tensor([2., 2., 2.]))
def test_nx_conversion():
@parametrize_dtype
def atest_nx_conversion(index_dtype):
# check conversion between networkx and DGLGraph # check conversion between networkx and DGLGraph
def _check_nx_feature(nxg, nf, ef): def _check_nx_feature(nxg, nf, ef):
...@@ -207,7 +222,7 @@ def test_nx_conversion(): ...@@ -207,7 +222,7 @@ def test_nx_conversion():
n3 = F.randn((5, 4)) n3 = F.randn((5, 4))
e1 = F.randn((4, 5)) e1 = F.randn((4, 5))
e2 = F.randn((4, 7)) e2 = F.randn((4, 7))
g = dgl.graph([(0,2),(1,4),(3,0),(4,3)]) g = dgl.graph([(0,2),(1,4),(3,0),(4,3)], index_dtype=index_dtype)
g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3}) g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
g.edata.update({'e1': e1, 'e2': e2}) g.edata.update({'e1': e1, 'e2': e2})
...@@ -219,7 +234,8 @@ def test_nx_conversion(): ...@@ -219,7 +234,8 @@ def test_nx_conversion():
# convert to DGLGraph, nx graph has id in edge feature # convert to DGLGraph, nx graph has id in edge feature
# use id feature to test non-tensor copy # use id feature to test non-tensor copy
g = dgl.graph(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id']) g = dgl.graph(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id'], index_dtype=index_dtype)
assert g._idtype_str == index_dtype
# check graph size # check graph size
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4 assert g.number_of_edges() == 4
...@@ -289,61 +305,67 @@ def test_nx_conversion(): ...@@ -289,61 +305,67 @@ def test_nx_conversion():
assert F.allclose(g.edata['h'], F.tensor([[1., 2.], [1., 2.], assert F.allclose(g.edata['h'], F.tensor([[1., 2.], [1., 2.],
[2., 3.], [2., 3.]])) [2., 3.], [2., 3.]]))
def test_batch_send(): @parametrize_dtype
g = generate_graph() def test_batch_send(index_dtype):
g = generate_graph(index_dtype=index_dtype)
def _fmsg(edges): def _fmsg(edges):
assert tuple(F.shape(edges.src['h'])) == (5, D) assert tuple(F.shape(edges.src['h'])) == (5, D)
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
# many-many send # many-many send
u = F.tensor([0, 0, 0, 0, 0]) u = F.tensor([0, 0, 0, 0, 0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 4, 5]) v = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg) g.send((u, v), _fmsg)
# one-many send # one-many send
u = F.tensor([0]) u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 4, 5]) v = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg) g.send((u, v), _fmsg)
# many-one send # many-one send
u = F.tensor([1, 2, 3, 4, 5]) u = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9]) v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg) g.send((u, v), _fmsg)
def test_batch_recv(): @parametrize_dtype
def test_batch_recv(index_dtype):
# basic recv test # basic recv test
g = generate_graph() g = generate_graph(index_dtype=index_dtype)
u = F.tensor([0, 0, 0, 4, 5, 6]) u = F.tensor([0, 0, 0, 4, 5, 6], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 9, 9, 9]) v = F.tensor([1, 2, 3, 9, 9, 9], dtype=F.data_type_dict[index_dtype])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
g.send((u, v), message_func) g.send((u, v), message_func)
g.recv(F.unique(v), reduce_func, apply_node_func) g.recv(F.astype(F.unique(v), F.data_type_dict[index_dtype]), reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_apply_nodes():
@parametrize_dtype
def test_apply_nodes(index_dtype):
def _upd(nodes): def _upd(nodes):
return {'h' : nodes.data['h'] * 2} return {'h' : nodes.data['h'] * 2}
g = generate_graph() g = generate_graph(index_dtype=index_dtype)
old = g.ndata['h'] old = g.ndata['h']
g.apply_nodes(_upd) g.apply_nodes(_upd)
assert F.allclose(old * 2, g.ndata['h']) assert F.allclose(old * 2, g.ndata['h'])
u = F.tensor([0, 3, 4, 6]) u = F.tensor([0, 3, 4, 6], F.data_type_dict[index_dtype])
g.apply_nodes(lambda nodes : {'h' : nodes.data['h'] * 0.}, u) g.apply_nodes(lambda nodes : {'h' : nodes.data['h'] * 0.}, u)
assert F.allclose(F.gather_row(g.ndata['h'], u), F.zeros((4, D))) assert F.allclose(F.gather_row(g.ndata['h'], u), F.zeros((4, D)))
def test_apply_edges(): @parametrize_dtype
def test_apply_edges(index_dtype):
def _upd(edges): def _upd(edges):
return {'w' : edges.data['w'] * 2} return {'w' : edges.data['w'] * 2}
g = generate_graph() g = generate_graph(index_dtype=index_dtype)
old = g.edata['w'] old = g.edata['w']
g.apply_edges(_upd) g.apply_edges(_upd)
assert F.allclose(old * 2, g.edata['w']) assert F.allclose(old * 2, g.edata['w'])
u = F.tensor([0, 0, 0, 4, 5, 6]) u = F.tensor([0, 0, 0, 4, 5, 6], F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 9, 9, 9]) v = F.tensor([1, 2, 3, 9, 9, 9], F.data_type_dict[index_dtype])
g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v)) g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
eid = F.tensor(g.edge_ids(u, v)) eid = F.tensor(g.edge_ids(u, v), F.data_type_dict[index_dtype])
assert F.allclose(F.gather_row(g.edata['w'], eid), F.zeros((6, D))) assert F.allclose(F.gather_row(g.edata['w'], eid), F.zeros((6, D)))
def test_update_routines(): @parametrize_dtype
g = generate_graph() def test_update_routines(index_dtype):
g = generate_graph(index_dtype=index_dtype)
# send_and_recv # send_and_recv
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -359,14 +381,14 @@ def test_update_routines(): ...@@ -359,14 +381,14 @@ def test_update_routines():
pass pass
# pull # pull
v = F.tensor([1, 2, 3, 9]) v = F.tensor([1, 2, 3, 9], F.data_type_dict[index_dtype])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
g.pull(v, message_func, reduce_func, apply_node_func) g.pull(v, message_func, reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
# push # push
v = F.tensor([0, 1, 2, 3]) v = F.tensor([0, 1, 2, 3], F.data_type_dict[index_dtype])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
g.push(v, message_func, reduce_func, apply_node_func) g.push(v, message_func, reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
...@@ -378,9 +400,10 @@ def test_update_routines(): ...@@ -378,9 +400,10 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)}) assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_recv_0deg(): @parametrize_dtype
def test_recv_0deg(index_dtype):
# test recv with 0deg nodes; # test recv with 0deg nodes;
g = dgl.graph([(0,1)]) g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges): def _message(edges):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
...@@ -412,9 +435,11 @@ def test_recv_0deg(): ...@@ -412,9 +435,11 @@ def test_recv_0deg():
# non-0deg check: untouched # non-0deg check: untouched
assert F.allclose(new[1], old[1]) assert F.allclose(new[1], old[1])
def test_recv_0deg_newfld():
@parametrize_dtype
def test_recv_0deg_newfld(index_dtype):
# test recv with 0deg nodes; the reducer also creates a new field # test recv with 0deg nodes; the reducer also creates a new field
g = dgl.graph([(0,1)]) g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges): def _message(edges):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
...@@ -447,9 +472,10 @@ def test_recv_0deg_newfld(): ...@@ -447,9 +472,10 @@ def test_recv_0deg_newfld():
# non-0deg check: not changed # non-0deg check: not changed
assert F.allclose(new[1], F.full_1d(5, -1, F.int64)) assert F.allclose(new[1], F.full_1d(5, -1, F.int64))
def test_update_all_0deg(): @parametrize_dtype
def test_update_all_0deg(index_dtype):
# test#1 # test#1
g = dgl.graph([(1,0), (2,0), (3,0), (4,0)]) g = dgl.graph([(1,0), (2,0), (3,0), (4,0)], index_dtype=index_dtype)
def _message(edges): def _message(edges):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
...@@ -470,7 +496,7 @@ def test_update_all_0deg(): ...@@ -470,7 +496,7 @@ def test_update_all_0deg():
assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0)) assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0))
# test#2: # test#2:
g = dgl.graph([], num_nodes=5) g = dgl.graph([], num_nodes=5, index_dtype=index_dtype)
g.set_n_initializer(_init2, 'h') g.set_n_initializer(_init2, 'h')
g.ndata['h'] = old_repr g.ndata['h'] = old_repr
g.update_all(_message, _reduce, _apply) g.update_all(_message, _reduce, _apply)
...@@ -478,8 +504,9 @@ def test_update_all_0deg(): ...@@ -478,8 +504,9 @@ def test_update_all_0deg():
# should fallback to apply # should fallback to apply
assert F.allclose(new_repr, 2*old_repr) assert F.allclose(new_repr, 2*old_repr)
def test_pull_0deg(): @parametrize_dtype
g = dgl.graph([(0,1)]) def test_pull_0deg(index_dtype):
g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges): def _message(edges):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
...@@ -509,8 +536,9 @@ def test_pull_0deg(): ...@@ -509,8 +536,9 @@ def test_pull_0deg():
# non-0deg check: not touched # non-0deg check: not touched
assert F.allclose(new[1], old[1]) assert F.allclose(new[1], old[1])
def test_send_multigraph(): @parametrize_dtype
g = dgl.graph([(0,1), (0,1), (0,1), (2,1)]) def test_send_multigraph(index_dtype):
g = dgl.graph([(0,1), (0,1), (0,1), (2,1)], index_dtype=index_dtype)
def _message_a(edges): def _message_a(edges):
return {'a': edges.data['a']} return {'a': edges.data['a']}
...@@ -607,9 +635,9 @@ def _test_dynamic_addition(): ...@@ -607,9 +635,9 @@ def _test_dynamic_addition():
g.add_edge(2, 1, {'h1': F.randn((1, D))}) g.add_edge(2, 1, {'h1': F.randn((1, D))})
assert len(g.edata['h1']) == len(g.edata['h2']) assert len(g.edata['h1']) == len(g.edata['h2'])
@parametrize_dtype
def test_repr(): def test_repr(index_dtype):
G = dgl.graph([(0,1), (0,2), (1,2)], num_nodes=10) G = dgl.graph([(0,1), (0,2), (1,2)], num_nodes=10, index_dtype=index_dtype)
repr_string = G.__repr__() repr_string = G.__repr__()
print(repr_string) print(repr_string)
G.ndata['x'] = F.zeros((10, 5)) G.ndata['x'] = F.zeros((10, 5))
...@@ -618,7 +646,8 @@ def test_repr(): ...@@ -618,7 +646,8 @@ def test_repr():
print(repr_string) print(repr_string)
def test_group_apply_edges(): @parametrize_dtype
def test_group_apply_edges(index_dtype):
def edge_udf(edges): def edge_udf(edges):
h = F.sum(edges.data['feat'] * (edges.src['h'] + edges.dst['h']), dim=2) h = F.sum(edges.data['feat'] * (edges.src['h'] + edges.dst['h']), dim=2)
normalized_feat = F.softmax(h, dim=1) normalized_feat = F.softmax(h, dim=1)
...@@ -631,7 +660,7 @@ def test_group_apply_edges(): ...@@ -631,7 +660,7 @@ def test_group_apply_edges():
elist.append((1, v)) elist.append((1, v))
for v in [2, 3, 4, 5, 6, 7, 8]: for v in [2, 3, 4, 5, 6, 7, 8]:
elist.append((2, v)) elist.append((2, v))
g = dgl.graph(elist) g = dgl.graph(elist, index_dtype=index_dtype)
g.ndata['h'] = F.randn((g.number_of_nodes(), D)) g.ndata['h'] = F.randn((g.number_of_nodes(), D))
g.edata['feat'] = F.randn((g.number_of_edges(), D)) g.edata['feat'] = F.randn((g.number_of_edges(), D))
...@@ -653,8 +682,9 @@ def test_group_apply_edges(): ...@@ -653,8 +682,9 @@ def test_group_apply_edges():
# test group by destination nodes # test group by destination nodes
_test('dst') _test('dst')
def test_local_var(): @parametrize_dtype
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)]) def test_local_var(index_dtype):
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)], index_dtype=index_dtype)
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3)) g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
g.edata['w'] = F.zeros((g.number_of_edges(), 4)) g.edata['w'] = F.zeros((g.number_of_edges(), 4))
# test override # test override
...@@ -710,8 +740,9 @@ def test_local_var(): ...@@ -710,8 +740,9 @@ def test_local_var():
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]])) assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g) foo(g)
def test_local_scope(): @parametrize_dtype
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)]) def test_local_scope(index_dtype):
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)], index_dtype=index_dtype)
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3)) g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
g.edata['w'] = F.zeros((g.number_of_edges(), 4)) g.edata['w'] = F.zeros((g.number_of_edges(), 4))
# test override # test override
...@@ -762,7 +793,7 @@ def test_local_scope(): ...@@ -762,7 +793,7 @@ def test_local_scope():
assert 'ww' not in g.edata assert 'ww' not in g.edata
# test initializer1 # test initializer1
g = dgl.graph([(0,1), (1,1)]) g = dgl.graph([(0,1), (1,1)], index_dtype=index_dtype)
g.set_n_initializer(dgl.init.zero_initializer) g.set_n_initializer(dgl.init.zero_initializer)
def foo(g): def foo(g):
with g.local_scope(): with g.local_scope():
...@@ -781,32 +812,35 @@ def test_local_scope(): ...@@ -781,32 +812,35 @@ def test_local_scope():
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]])) assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g) foo(g)
def test_issue_1088(): @parametrize_dtype
def test_issue_1088(index_dtype):
# This test ensures that message passing on a heterograph with one edge type # This test ensures that message passing on a heterograph with one edge type
# would not crash (GitHub issue #1088). # would not crash (GitHub issue #1088).
import dgl.function as fn import dgl.function as fn
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])}) g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])}, index_dtype=index_dtype)
g.nodes['U'].data['x'] = F.randn((3, 3)) g.nodes['U'].data['x'] = F.randn((3, 3))
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y')) g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
if __name__ == '__main__': if __name__ == '__main__':
test_isolated_nodes() # test_isolated_nodes("int32")
test_nx_conversion() # test_nx_conversion()
test_batch_setter_getter() # test_batch_setter_getter("int32")
test_batch_setter_autograd() # test_batch_recv("int64")
test_batch_send() test_apply_edges("int32")
test_batch_recv() # test_batch_setter_autograd()
test_apply_nodes() # test_batch_send()
test_apply_edges() # test_batch_recv()
test_update_routines() # test_apply_nodes()
test_recv_0deg() # test_apply_edges()
test_recv_0deg_newfld() # test_update_routines()
test_update_all_0deg() # test_recv_0deg()
test_pull_0deg() # test_recv_0deg_newfld()
test_send_multigraph() # test_update_all_0deg()
test_dynamic_addition() # test_pull_0deg()
test_repr() # test_send_multigraph()
test_group_apply_edges() # test_dynamic_addition()
test_local_var() # test_repr()
test_local_scope() # test_group_apply_edges()
test_issue_1088() # test_local_var()
# test_local_scope()
# test_issue_1088()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment