Unverified Commit dc8ca88e authored by Jinjing Zhou, committed by GitHub

[Refactor] Explicit dtype for HeteroGraph (#1467)



* 111

* 111

* lint

* lint

* lint

* lint

* fix

* lint

* try

* fix

* lint

* lint

* test

* fix

* ttt

* test

* fix

* fix

* fix

* mxnet

* 111

* fix 64bits computation

* pylint

* roll back

* fix

* lint

* fix hetero_from_relations

* remove index_dtype in to_homo and to_hetero

* fix

* fix

* fix

* fix

* remove default

* fix

* lint

* fix

* fix error message

* fix error

* lint

* macro dispatch

* try

* lint

* remove nbits

* error message

* fix

* fix

* lint

* lint

* lint

* fix

* lint

* fix

* fix random walk

* lint

* lint

* fix

* fix

* fix

* lint

* fix

* lint
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent de34e15a
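
The diff below threads an explicit index dtype through DGL's Python and C++ layers. As a quick orientation, here is a minimal sketch of the user-facing behavior, based on the `index_dtype` keyword and the `_idtype_str`/`idtype` attributes exercised by the tests in this commit (the edge lists are illustrative only):

```python
import dgl

# Graphs can now be built with an explicit index dtype ('int32' or 'int64').
g32 = dgl.graph([(0, 1), (1, 2)], index_dtype='int32')
g64 = dgl.graph([(0, 1), (1, 2)])  # the default remains int64

assert g32._idtype_str == 'int32'
assert g64._idtype_str == 'int64'
```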
......@@ -110,7 +110,7 @@ def schedule_recv(graph,
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
# sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
recv_nodes = utils.toindex(recv_nodes, graph.gidx.dtype)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# reduce
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
......@@ -161,7 +161,7 @@ def schedule_snr(graph,
"""
u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
recv_nodes = utils.toindex(recv_nodes, graph.gidx.dtype)
# create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, 'dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
......@@ -216,13 +216,13 @@ def schedule_update_all(graph,
if graph.num_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.num_dst()))
nodes = utils.toindex(slice(0, graph.num_dst()), graph.gidx.dtype)
schedule_apply_nodes(nodes, apply_func, graph.dstframe,
inplace=False, outframe=outframe,
ntype=graph.canonical_etype[-1])
else:
eid = utils.toindex(slice(0, graph.num_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst())) # ALL
eid = utils.toindex(slice(0, graph.num_edges()), graph.gidx.dtype) # ALL
recv_nodes = utils.toindex(slice(0, graph.num_dst()), graph.gidx.dtype) # ALL
# create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
......@@ -484,8 +484,9 @@ def schedule_pull(graph,
schedule_apply_nodes(pull_nodes, apply_func, graph.dstframe, inplace,
outframe, ntype=graph.canonical_etype[-1])
else:
# TODO(Allen): Change operation to dgl operation
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
pull_nodes = utils.toindex(pull_nodes, graph.gidx.dtype)
# create vars
var_dst_nf = var.FEAT_DICT(graph.dstframe, name='dst_nf')
var_out_nf = var_dst_nf if outframe is None else var.FEAT_DICT(outframe, name='out_nf')
......@@ -953,7 +954,7 @@ def _gen_send_reduce(
return var_out
else:
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data)))
mid = utils.toindex(slice(0, len(var_v.data)), var_v.data.dtype)
db.gen_degree_bucketing_schedule(rfunc, mid, var_v.data,
reduce_nodes, var_dst_nf, var_mf,
var_out, ntype=canonical_etype[-1])
......
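Every `utils.toindex` call in the scheduler now receives `graph.gidx.dtype`, so the scheduler's `Index` objects match the graph's index dtype instead of assuming int64. A minimal sketch of the new signature (the values are hypothetical):

```python
from dgl import utils

# toindex now takes a dtype argument; omitting it keeps the old int64 default.
idx = utils.toindex([0, 2, 4], 'int32')
assert idx.dtype == 'int32'
```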
......@@ -463,6 +463,7 @@ def metapath_reachable_graph(g, metapath):
A homogeneous or bipartite graph.
"""
adj = 1
index_dtype = g._idtype_str
for etype in metapath:
adj = adj * g.adj(etype=etype, scipy_fmt='csr', transpose=True)
......@@ -471,9 +472,9 @@ def metapath_reachable_graph(g, metapath):
dsttype = g.to_canonical_etype(metapath[-1])[2]
if srctype == dsttype:
assert adj.shape[0] == adj.shape[1]
new_g = graph(adj, ntype=srctype)
new_g = graph(adj, ntype=srctype, index_dtype=index_dtype)
else:
new_g = bipartite(adj, utype=srctype, vtype=dsttype)
new_g = bipartite(adj, utype=srctype, vtype=dsttype, index_dtype=index_dtype)
for key, value in g.nodes[srctype].data.items():
new_g.nodes[srctype].data[key] = value
......@@ -744,14 +745,16 @@ def compact_graphs(graphs, always_preserve=None):
# Ensure the node types are ordered the same.
# TODO(BarclayII): we ideally need to remove this constraint.
ntypes = graphs[0].ntypes
graph_dtype = graphs[0]._graph.dtype()
graph_dtype = graphs[0]._idtype_str
graph_ctx = graphs[0]._graph.ctx()
for g in graphs:
assert ntypes == g.ntypes, \
("All graphs should have the same node types in the same order, got %s and %s" %
ntypes, g.ntypes)
assert graph_dtype == g._graph.dtype(), "Graph data type mismatch"
assert graph_ctx == g._graph.ctx(), "Graph device mismatch"
assert graph_dtype == g._idtype_str, "Expect graph data type to be {}, but got {}".format(
graph_dtype, g._idtype_str)
assert graph_ctx == g._graph.ctx(), "Expect graph device to be {}, but got {}".format(
graph_ctx, g._graph.ctx())
# Process the dictionary or tensor of "always preserve" nodes
if always_preserve is None:
......@@ -919,7 +922,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
if nodes is not None:
dst_nodes_nd.append(F.zerocopy_to_dgl_ndarray(nodes))
else:
dst_nodes_nd.append(nd.NULL)
dst_nodes_nd.append(nd.NULL[g._idtype_str])
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src)
......@@ -935,7 +938,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
new_graph.dstnodes[ntype].data[NID] = dst_nodes[ntype]
else:
# For empty dst node sets, still create empty mapping arrays.
new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=F.int64)
new_graph.dstnodes[ntype].data[NID] = F.tensor([], dtype=g.idtype)
for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data)
......@@ -970,8 +973,12 @@ def remove_edges(g, edge_ids):
"Graph has more than one edge type; specify a dict for edge_id instead.")
edge_ids = {g.canonical_etypes[0]: edge_ids}
edge_ids_nd = [nd.NULL] * len(g.etypes)
edge_ids_nd = [nd.NULL[g._idtype_str]] * len(g.etypes)
for key, value in edge_ids.items():
if value.dtype != g.idtype:
# Without this check the call would still run, but it could silently return wrong results.
raise utils.InconsistentDtypeException(
"Expect edge id tensors ({}) to have the same index type as "
"the graph ({})".format(value.dtype, g.idtype))
edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value)
new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd)
......@@ -1018,9 +1025,9 @@ def in_subgraph(g, nodes):
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
nodes_all_types.append(utils.toindex(nodes[ntype], g._idtype_str).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
nodes_all_types.append(nd.NULL[g._idtype_str])
subgidx = _CAPI_DGLInSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges
......@@ -1057,9 +1064,9 @@ def out_subgraph(g, nodes):
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
nodes_all_types.append(utils.toindex(nodes[ntype], g._idtype_str).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
nodes_all_types.append(nd.NULL[g._idtype_str])
subgidx = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges
......@@ -1135,7 +1142,7 @@ def to_simple(g, return_counts='count', writeback_mapping=None):
def as_heterograph(g, ntype='_U', etype='_E'):
"""Convert a DGLGraph to a DGLHeteroGraph with one node and edge type.
Node and edge features are preserved.
Node and edge features are preserved. The returned graph always uses 64-bit indices.
Parameters
----------
......
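With these changes, the transform routines validate that user-supplied id tensors match the graph's index dtype instead of silently producing wrong results. A hedged sketch, assuming `remove_edges` is invoked through `dgl.transform` as defined in this file:

```python
import dgl
import dgl.backend as F

g = dgl.graph([(0, 1), (1, 2)], index_dtype='int32')
g2 = dgl.transform.remove_edges(g, F.tensor([0], dtype=F.int32))  # dtypes match

# An int64 edge-id tensor on an int32 graph now raises
# InconsistentDtypeException instead of returning a wrong result:
# dgl.transform.remove_edges(g, F.tensor([0], dtype=F.int64))
```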
......@@ -9,9 +9,19 @@ from .base import DGLError
from . import backend as F
from . import ndarray as nd
class InconsistentDtypeException(DGLError):
"""Exception class for inconsistent dtype between graph and tensor"""
def __init__(self, msg='', *args, **kwargs): #pylint: disable=W1113
prefix_message = ('DGL now requires the input tensor to have the same '
'dtype as the graph index\'s dtype (which you can get via g.idtype). ')
super().__init__(prefix_message + msg, *args, **kwargs)
class Index(object):
"""Index class that can be easily converted to list/tensor."""
def __init__(self, data):
def __init__(self, data, dtype="int64"):
assert dtype in ['int32', 'int64']
self.dtype = dtype
self._initialize_data(data)
def _initialize_data(self, data):
......@@ -43,18 +53,22 @@ class Index(object):
def _dispatch(self, data):
"""Store data based on its type."""
if F.is_tensor(data):
if F.dtype(data) != F.int64:
raise DGLError('Index data must be an int64 vector, but got: %s' % str(data))
if F.dtype(data) != F.data_type_dict[self.dtype]:
raise InconsistentDtypeException('Index data specified as %s, but got: %s' %
(self.dtype,
F.reverse_data_type_dict[F.dtype(data)]))
if len(F.shape(data)) > 1:
raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data))
raise InconsistentDtypeException('Index data must be a 1D int32/int64 '
'vector, but got shape: %s' % str(F.shape(data)))
if len(F.shape(data)) == 0:
# a tensor of one int
self._dispatch(int(data))
else:
self._user_tensor_data[F.context(data)] = data
elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1):
raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data))
if not (data.dtype == self.dtype and len(data.shape) == 1):
raise InconsistentDtypeException('Index data must be 1D %s vector, but got: %s' %
(self.dtype, data.dtype))
self._dgl_tensor_data = data
elif isinstance(data, slice):
# save it in the _pydata temporarily; materialize it if `tonumpy` is called
......@@ -63,7 +77,7 @@ class Index(object):
self._slice_data = slice(data.start, data.stop)
else:
try:
data = np.asarray(data, dtype=np.int64)
data = np.asarray(data, dtype=self.dtype)
except Exception: # pylint: disable=broad-except
raise DGLError('Error index data: %s' % str(data))
if data.ndim == 0: # scalar array
......@@ -79,7 +93,7 @@ class Index(object):
if self._pydata is None:
if self._slice_data is not None:
slc = self._slice_data
self._pydata = np.arange(slc.start, slc.stop).astype(np.int64)
self._pydata = np.arange(slc.start, slc.stop).astype(self.dtype)
elif self._dgl_tensor_data is not None:
self._pydata = self._dgl_tensor_data.asnumpy()
else:
......@@ -128,12 +142,13 @@ class Index(object):
def __getstate__(self):
if self._slice_data is not None:
# the index can be represented by a slice
return self._slice_data
return self._slice_data, self.dtype
else:
return self.tousertensor()
return self.tousertensor(), self.dtype
def __setstate__(self, state):
self._initialize_data(state)
data, self.dtype = state
self._initialize_data(data)
def get_items(self, index):
"""Return values at given positions of an Index
......@@ -155,18 +170,22 @@ class Index(object):
# the provided index is not a slice
tensor = self.tousertensor()
index = index.tousertensor()
return Index(F.gather_row(tensor, index))
# TODO(Allen): Change F.gather_row to dgl operation
return Index(F.gather_row(tensor, index), self.dtype)
elif self._slice_data is None:
# the current index is not a slice but the provided is a slice
tensor = self.tousertensor()
index = index._slice_data
return Index(F.narrow_row(tensor, index.start, index.stop))
# TODO(Allen): Change F.narrow_row to dgl operation
return Index(F.astype(F.narrow_row(tensor, index.start, index.stop),
F.data_type_dict[self.dtype]),
self.dtype)
else:
# both self and index wrap a slice object, then return another
# Index wrapping a slice
start = self._slice_data.start
index = index._slice_data
return Index(slice(start + index.start, start + index.stop))
return Index(slice(start + index.start, start + index.stop), self.dtype)
def set_items(self, index, value):
"""Set values at given positions of an Index. Set is not done in place,
......@@ -191,7 +210,7 @@ class Index(object):
value = F.full_1d(len(index), value, dtype=F.data_type_dict[self.dtype], ctx=F.cpu())
else:
value = value.tousertensor()
return Index(F.scatter_row(tensor, index, value))
return Index(F.scatter_row(tensor, index, value), self.dtype)
def append_zeros(self, num):
"""Append zeros to an Index
......@@ -205,24 +224,24 @@ class Index(object):
return self
new_items = F.zeros((num,), dtype=F.data_type_dict[self.dtype], ctx=F.cpu())
if len(self) == 0:
return Index(new_items)
return Index(new_items, self.dtype)
else:
tensor = self.tousertensor()
tensor = F.cat((tensor, new_items), dim=0)
return Index(tensor)
return Index(tensor, self.dtype)
def nonzero(self):
"""Return the nonzero positions"""
tensor = self.tousertensor()
mask = F.nonzero_1d(tensor != 0)
return Index(mask)
return Index(mask, self.dtype)
def has_nonzero(self):
"""Check if there is any nonzero value in this Index"""
tensor = self.tousertensor()
return F.sum(tensor, 0) > 0
def toindex(data):
def toindex(data, dtype='int64'):
"""Convert the given data to Index object.
Parameters
......@@ -239,16 +258,17 @@ def toindex(data):
--------
Index
"""
return data if isinstance(data, Index) else Index(data)
return data if isinstance(data, Index) else Index(data, dtype)
def zero_index(size):
def zero_index(size, dtype="int64"):
"""Create a index with provided size initialized to zero
Parameters
----------
size: int
dtype: str
    Index data type, either 'int32' or 'int64'.
"""
return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu()))
return Index(F.zeros((size,), dtype=F.data_type_dict[dtype], ctx=F.cpu()),
dtype=dtype)
def set_diff(ar1, ar2):
"""Find the set difference of two index arrays.
......
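`Index` now carries its dtype explicitly: Python lists are cast to the requested dtype, while framework tensors must already match it. A small sketch of the new constructor:

```python
from dgl import utils

idx = utils.Index([0, 1, 2], dtype='int32')  # the numpy path casts the list to int32
assert idx.dtype == 'int32'

# A framework tensor whose dtype disagrees with the declared dtype now
# raises InconsistentDtypeException rather than being assumed int64.
```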
......@@ -277,7 +277,8 @@ class HeteroNodeView(object):
def __call__(self, ntype=None):
"""Return the nodes."""
return F.arange(0, self._graph.number_of_nodes(ntype))
return F.arange(0, self._graph.number_of_nodes(ntype),
dtype=self._graph._idtype_str)
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
......
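The node view also allocates its id range in the graph's own dtype, so downstream tensors stay consistent. For example (a sketch under the same assumptions as above):

```python
g = dgl.graph([(0, 1), (1, 2)], index_dtype='int32')
nids = g.nodes()  # node ids are returned in the graph's index dtype (int32 here)
```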
......@@ -115,7 +115,7 @@ template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L);
IdArray ret = NewIdArray(2 * L, DLContext{kDLCPU, 0}, arr1->dtype.bits);
const IdType* arr1_data = static_cast<IdType*>(arr1->data);
const IdType* arr2_data = static_cast<IdType*>(arr2->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
......@@ -173,7 +173,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
}
}
// map array
IdArray maparr = NewIdArray(newid);
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first;
......
......@@ -10,6 +10,13 @@
#include <vector>
#include <unordered_map>
#include <utility>
#include "../../c_api_common.h"
#define CHECK_SAME_DTYPE(VAR1, VAR2) \
CHECK(VAR1->dtype == VAR2->dtype) \
<< "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) << "(" \
<< (VAR1)->dtype << ")" \
<< ". But got " << (VAR2)->dtype;
namespace dgl {
......
......@@ -22,6 +22,7 @@ void CSRRemoveConsecutive(
std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t n_entries = entries->shape[0];
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
......@@ -54,6 +55,7 @@ void CSRRemoveShuffled(
std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
const IdType *eid_data = static_cast<IdType *>(csr.data->data);
......@@ -77,6 +79,7 @@ void CSRRemoveShuffled(
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0];
const int64_t n_entries = entries->shape[0];
if (n_entries == 0)
......
......@@ -43,6 +43,8 @@ template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
CHECK_SAME_DTYPE(csr.indices, row);
CHECK_SAME_DTYPE(csr.indices, col);
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
const auto rstlen = std::max(rowlen, collen);
......@@ -98,6 +100,7 @@ template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -194,6 +197,8 @@ template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
......@@ -261,6 +266,8 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
// TODO(minjie): more efficient implementation for matrix without duplicate entries
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
......@@ -448,6 +455,7 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
......@@ -494,6 +502,8 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0];
......
......@@ -106,6 +106,7 @@ template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
CHECK_SAME_DTYPE(coo.col, rows);
const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data);
NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
......@@ -171,8 +172,10 @@ template NDArray COOGetData<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
///////////////////////////// COOGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(
COOMatrix coo, NDArray rows, NDArray cols) {
std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray cols) {
CHECK_SAME_DTYPE(coo.col, rows);
CHECK_SAME_DTYPE(coo.col, cols);
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
const int64_t len = std::max(rowlen, collen);
......
......@@ -235,12 +235,34 @@ HeteroSubgraph HeteroGraph::EdgeSubgraph(
}
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etypes) const {
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs;
for (auto rg : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::AsNumBits(rg, bits));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
const std::vector<dgl_type_t>& etypes) const {
const int64_t bits = NumBits();
if (bits == 32) {
return FlattenImpl<int32_t>(etypes);
} else if (bits == 64) {
return FlattenImpl<int64_t>(etypes);
} else {
LOG(FATAL) << "Unsupported index dtype: int" << bits;
return nullptr;
}
}
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_id_t> result_src, result_dst;
std::vector<IdType> result_src, result_dst;
std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<dgl_id_t> induced_srcid, induced_eid, induced_dstid;
std::vector<IdType> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this
......@@ -261,7 +283,6 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
dsttype_set.push_back(dsttype);
}
}
// Sort the node types so that we can compare the sets and decide whether a homograph
// should be returned.
std::sort(srctype_set.begin(), srctype_set.end());
......@@ -301,9 +322,9 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const dgl_id_t* edges_eid_data = static_cast<const dgl_id_t*>(edges.id->data);
const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
const IdType* edges_eid_data = static_cast<const IdType*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
......
......@@ -202,6 +202,9 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
private:
// To create empty class
friend class Serializer;
......@@ -214,6 +217,15 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_;
/*! \brief template class for Flatten operation
*
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param etypes vector of etypes to be flattened
* \return pointer to the resulting FlattenedHeteroGraph
*/
template <class IdType>
FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const;
};
} // namespace dgl
......
......@@ -3,10 +3,12 @@
* \file graph/heterograph_capi.cc
* \brief Heterograph CAPI bindings.
*/
#include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "../c_api_common.h"
#include "./heterograph.h"
using namespace dgl::runtime;
......@@ -409,7 +411,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int bits = args[1];
HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
HeteroGraphPtr bhg_ptr = hg.sptr();
auto hg_ptr = std::dynamic_pointer_cast<HeteroGraph>(bhg_ptr);
HeteroGraphPtr hg_new;
if (hg_ptr) {
hg_new = HeteroGraph::AsNumBits(hg_ptr, bits);
} else {
hg_new = UnitGraph::AsNumBits(bhg_ptr, bits);
}
*rv = HeteroGraphRef(hg_new);
});
......@@ -429,13 +438,22 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 0)
<< "Expect graph list has at least one graph";
std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits();
for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to batch have the same index dtype(int" << bits
<< "), but got int" << component->NumBits();
}
auto hgptr = DisjointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
ATEN_ID_BITS_SWITCH(bits, IdType, {
auto hgptr =
DisjointUnionHeteroGraph<IdType>(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
......@@ -443,8 +461,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2];
const auto& ret = DisjointPartitionHeteroBySizes(
hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
const int64_t bits = hg->NumBits();
std::vector<HeteroGraphPtr> ret;
ATEN_ID_BITS_SWITCH(bits, IdType, {
ret = DisjointPartitionHeteroBySizes<IdType>(hg->meta_graph(), hg.sptr(),
vertex_sizes, edge_sizes);
});
List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr));
......
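Disjoint union and partition are now templated on the id type, and batching rejects graphs with mixed index dtypes up front. At the Python level this surfaces roughly as follows (a sketch; the graphs are illustrative):

```python
import dgl

follows = {('user', 'follows', 'user'): [(0, 1), (1, 2)]}
g32 = dgl.heterograph(follows, index_dtype='int32')
g64 = dgl.heterograph(follows, index_dtype='int64')

bg = dgl.batch_hetero([g32, g32])   # fine: both graphs are int32
# dgl.batch_hetero([g32, g64])      # fails the NumBits() check added above
```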
......@@ -7,7 +7,9 @@
#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/packed_func_ext.h>
#include "../../c_api_common.h"
#include "../unit_graph.h"
#include "randomwalk.h"
using namespace dgl::runtime;
......@@ -21,14 +23,23 @@ namespace {
/*!
* \brief Random walk based on the given metapath.
*
*
* \tparam IdType Index dtype of graph
* \param hg The heterograph
* \param etypes The metapath as an array of edge type IDs
* \param seeds The array of starting vertices for random walks
* \param num_traces Number of traces to generate for each starting vertex
* \note The metapath should have the same starting and ending node type.
*/
template <typename T>
RandomWalkTracesPtr MetapathRandomWalk(
const HeteroGraphPtr hg,
const IdArray etypes,
const IdArray seeds,
int num_traces);
template <>
RandomWalkTracesPtr MetapathRandomWalk<int64_t>(
const HeteroGraphPtr hg,
const IdArray etypes,
const IdArray seeds,
......@@ -74,10 +85,64 @@ RandomWalkTracesPtr MetapathRandomWalk(
return RandomWalkTracesPtr(tl);
}
/*!
* \brief This is a patch function for int32 HeteroGraph
* TODO: Refactor this with CSR and COO operations
*/
template <>
RandomWalkTracesPtr MetapathRandomWalk<int32_t>(
const HeteroGraphPtr hg,
const IdArray etypes,
const IdArray seeds,
int num_traces) {
const auto metagraph = hg->meta_graph();
uint64_t num_etypes = etypes->shape[0];
uint64_t num_seeds = seeds->shape[0];
const dgl_type_t *etype_data = static_cast<dgl_type_t *>(etypes->data);
const int32_t *seed_data = static_cast<int32_t *>(seeds->data);
std::vector<int32_t> vertices;
std::vector<size_t> trace_lengths, trace_counts;
// TODO(quan): use omp to parallelize this loop
for (uint64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
int curr_num_traces = 0;
for (; curr_num_traces < num_traces; ++curr_num_traces) {
int32_t curr = seed_data[seed_id];
size_t trace_length = 0;
for (size_t i = 0; i < num_etypes; ++i) {
auto ug = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype_data[i]));
CHECK_NOTNULL(ug);
const auto &succ = ug->SuccVec32(etype_data[i], curr);
if (succ.size() == 0)
break;
curr = succ[RandomEngine::ThreadLocal()->RandInt(succ.size())];
vertices.push_back(curr);
++trace_length;
}
trace_lengths.push_back(trace_length);
}
trace_counts.push_back(curr_num_traces);
}
RandomWalkTraces *tl = new RandomWalkTraces;
tl->vertices = VecToIdArray(vertices);
tl->trace_lengths = VecToIdArray(trace_lengths);
tl->trace_counts = VecToIdArray(trace_counts);
return RandomWalkTracesPtr(tl);
}
}  // namespace
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef hg = args[0];
const IdArray etypes = args[1];
const IdArray seeds = args[2];
......@@ -89,7 +154,11 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "MetapathRandomWalk only support CPU sampling";
const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces);
const int64_t bits = hg->NumBits();
RandomWalkTracesPtr tl;
ATEN_ID_BITS_SWITCH(bits, IdType, {
tl = MetapathRandomWalk<IdType>(hg.sptr(), etypes, seeds, num_traces);
});
*rv = RandomWalkTracesRef(tl);
});
......
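The int32 specialization is a stopgap that walks the graph through `SuccVec32` rather than the 64-bit iterator. A hedged usage sketch follows; the Python wrapper name and signature are assumptions (only the C API registration `sampler.randomwalk._CAPI_DGLMetapathRandomWalk` appears in this diff):

```python
import dgl
# Assumed wrapper location and signature; not part of this diff.
from dgl.contrib.sampling import metapath_random_walk

hg = dgl.heterograph({('user', 'follows', 'user'): [(0, 1), (1, 2)]},
                     index_dtype='int32')
traces = metapath_random_walk(hg, ['follows', 'follows'], [0], 4)
```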
......@@ -8,6 +8,7 @@ using namespace dgl::runtime;
namespace dgl {
template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
......@@ -19,16 +20,16 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
dgl_id_t src_offset = 0, dst_offset = 0;
std::vector<dgl_id_t> result_src, result_dst;
IdType src_offset = 0, dst_offset = 0;
std::vector<IdType> result_src, result_dst;
// Loop over all graphs
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
EdgeArray edges = cg->Edges(etype);
size_t num_edges = cg->NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
// Loop over all edges
for (size_t j = 0; j < num_edges; ++j) {
......@@ -41,11 +42,9 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
dst_offset += cg->NumVertices(dst_vtype);
}
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
src_offset,
dst_offset,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
(src_vtype == dst_vtype) ? 1 : 2, src_offset, dst_offset,
aten::VecToIdArray(result_src, sizeof(IdType) * 8),
aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
......@@ -53,6 +52,13 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
template HeteroGraphPtr DisjointUnionHeteroGraph<int32_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template HeteroGraphPtr DisjointUnionHeteroGraph<int64_t>(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
// Sanity check for vertex sizes
......@@ -102,11 +108,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
EdgeArray edges = batched_graph->Edges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
// Loop over all graphs to be unbatched
for (uint64_t g = 0; g < batch_size; ++g) {
std::vector<dgl_id_t> result_src, result_dst;
std::vector<IdType> result_src, result_dst;
// Loop over the chunk of edges for the specified graph and edge type
for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) {
// TODO(mufei): Should use array operations to implement this.
......@@ -114,11 +120,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
}
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
vertex_sizes_data[src_vtype * batch_size + g],
vertex_sizes_data[dst_vtype * batch_size + g],
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
(src_vtype == dst_vtype) ? 1 : 2,
vertex_sizes_data[src_vtype * batch_size + g],
vertex_sizes_data[dst_vtype * batch_size + g],
aten::VecToIdArray(result_src, sizeof(IdType) * 8),
aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
rel_graphs[g].push_back(rgptr);
}
}
......@@ -133,4 +139,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
return rst;
}
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
} // namespace dgl
......@@ -645,6 +645,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid];
......@@ -652,9 +653,20 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return DGLIdIters(indices_data + start, indices_data + end);
}
DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
const int32_t start = indptr_data[vid];
const int32_t end = indptr_data[vid + 1];
return DGLIdIters32(indices_data + start, indices_data + end);
}
DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
const dgl_id_t start = indptr_data[vid];
......@@ -951,6 +963,13 @@ DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
return ptr->SuccVec(etype, vid);
}
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
CHECK_NOTNULL(ptr);
return ptr->SuccVec32(etype, vid);
}
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
......
......@@ -139,6 +139,9 @@ class UnitGraph : public BaseHeteroGraph {
DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;
// 32-bit variant of SuccVec (temporary patch)
DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;
DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;
DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;
......
......@@ -10,28 +10,29 @@
namespace dgl {
namespace sched {
template <class IdType>
std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
const IdArray& recv_ids) {
auto n_msgs = msg_ids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
const int64_t* msg_id_data = static_cast<int64_t*>(msg_ids->data);
const int64_t* recv_id_data = static_cast<int64_t*>(recv_ids->data);
const IdType* vid_data = static_cast<IdType*>(vids->data);
const IdType* msg_id_data = static_cast<IdType*>(msg_ids->data);
const IdType* recv_id_data = static_cast<IdType*>(recv_ids->data);
// in edge: dst->msgs
std::unordered_map<int64_t, std::vector<int64_t>> in_edges;
for (int64_t i = 0; i < n_msgs; ++i) {
std::unordered_map<IdType, std::vector<IdType>> in_edges;
for (IdType i = 0; i < n_msgs; ++i) {
in_edges[vid_data[i]].push_back(msg_id_data[i]);
}
// bkt: deg->dsts
std::unordered_map<int64_t, std::vector<int64_t>> bkt;
std::unordered_map<IdType, std::vector<IdType>> bkt;
for (const auto& it : in_edges) {
bkt[it.second.size()].push_back(it.first);
}
std::unordered_set<int64_t> zero_deg_nodes;
for (int64_t i = 0; i < recv_ids->shape[0]; ++i) {
std::unordered_set<IdType> zero_deg_nodes;
for (IdType i = 0; i < recv_ids->shape[0]; ++i) {
if (in_edges.find(recv_id_data[i]) == in_edges.end()) {
zero_deg_nodes.insert(recv_id_data[i]);
}
......@@ -39,9 +40,9 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
auto n_zero_deg = zero_deg_nodes.size();
// calc output size
int64_t n_deg = bkt.size();
int64_t n_dst = in_edges.size();
int64_t n_mid_sec = bkt.size(); // zero deg won't affect message size
IdType n_deg = bkt.size();
IdType n_dst = in_edges.size();
IdType n_mid_sec = bkt.size(); // zero deg won't affect message size
if (n_zero_deg > 0) {
n_deg += 1;
n_dst += n_zero_deg;
......@@ -53,16 +54,16 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
IdArray nid_section = IdArray::Empty({n_deg}, vids->dtype, vids->ctx);
IdArray mids = IdArray::Empty({n_msgs}, vids->dtype, vids->ctx);
IdArray mid_section = IdArray::Empty({n_mid_sec}, vids->dtype, vids->ctx);
int64_t* deg_ptr = static_cast<int64_t*>(degs->data);
int64_t* nid_ptr = static_cast<int64_t*>(nids->data);
int64_t* nsec_ptr = static_cast<int64_t*>(nid_section->data);
int64_t* mid_ptr = static_cast<int64_t*>(mids->data);
int64_t* msec_ptr = static_cast<int64_t*>(mid_section->data);
IdType* deg_ptr = static_cast<IdType*>(degs->data);
IdType* nid_ptr = static_cast<IdType*>(nids->data);
IdType* nsec_ptr = static_cast<IdType*>(nid_section->data);
IdType* mid_ptr = static_cast<IdType*>(mids->data);
IdType* msec_ptr = static_cast<IdType*>(mid_section->data);
// fill in bucketing ordering
for (const auto& it : bkt) { // for each bucket
const int64_t deg = it.first;
const int64_t bucket_size = it.second.size();
const IdType deg = it.first;
const IdType bucket_size = it.second.size();
*deg_ptr++ = deg;
*nsec_ptr++ = bucket_size;
*msec_ptr++ = deg * bucket_size;
......@@ -92,67 +93,82 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
return std::move(ret);
}
std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, const IdArray& vids,
const IdArray& eids) {
auto n_edge = eids->shape[0];
const int64_t* eid_data = static_cast<int64_t*>(eids->data);
const int64_t* uid_data = static_cast<int64_t*>(uids->data);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
// node2edge: group_by nodes uid -> (eid, the other end vid)
std::unordered_map<int64_t,
std::vector<std::pair<int64_t, int64_t>>> node2edge;
for (int64_t i = 0; i < n_edge; ++i) {
node2edge[uid_data[i]].emplace_back(eid_data[i], vid_data[i]);
template std::vector<IdArray> DegreeBucketing<int32_t>(const IdArray& msg_ids,
const IdArray& vids,
const IdArray& recv_ids);
template std::vector<IdArray> DegreeBucketing<int64_t>(const IdArray& msg_ids,
const IdArray& vids,
const IdArray& recv_ids);
template <class IdType>
std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids,
const IdArray& vids,
const IdArray& eids) {
auto n_edge = eids->shape[0];
const IdType* eid_data = static_cast<IdType*>(eids->data);
const IdType* uid_data = static_cast<IdType*>(uids->data);
const IdType* vid_data = static_cast<IdType*>(vids->data);
// node2edge: group_by nodes uid -> (eid, the other end vid)
std::unordered_map<IdType, std::vector<std::pair<IdType, IdType>>> node2edge;
for (IdType i = 0; i < n_edge; ++i) {
node2edge[uid_data[i]].emplace_back(eid_data[i], vid_data[i]);
}
// bkt: deg -> group_by node uid
std::unordered_map<IdType, std::vector<IdType>> bkt;
for (const auto& it : node2edge) {
bkt[it.second.size()].push_back(it.first);
}
// number of unique degree
IdType n_deg = bkt.size();
// initialize output
IdArray degs = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
IdArray new_uids = IdArray::Empty({n_edge}, uids->dtype, uids->ctx);
IdArray new_vids = IdArray::Empty({n_edge}, vids->dtype, vids->ctx);
IdArray new_eids = IdArray::Empty({n_edge}, eids->dtype, eids->ctx);
IdArray sections = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
IdType* deg_ptr = static_cast<IdType*>(degs->data);
IdType* uid_ptr = static_cast<IdType*>(new_uids->data);
IdType* vid_ptr = static_cast<IdType*>(new_vids->data);
IdType* eid_ptr = static_cast<IdType*>(new_eids->data);
IdType* sec_ptr = static_cast<IdType*>(sections->data);
// fill in bucketing ordering
for (const auto& it : bkt) { // for each bucket
// degree of this bucket
const IdType deg = it.first;
// number of edges in this bucket
const IdType bucket_size = it.second.size();
*deg_ptr++ = deg;
*sec_ptr++ = deg * bucket_size;
for (const auto u : it.second) { // for uid in this bucket
for (const auto& pair : node2edge[u]) { // for each edge of uid
*uid_ptr++ = u;
*vid_ptr++ = pair.second;
*eid_ptr++ = pair.first;
}
}
}
// bkt: deg -> group_by node uid
std::unordered_map<int64_t, std::vector<int64_t>> bkt;
for (const auto& it : node2edge) {
bkt[it.second.size()].push_back(it.first);
}
// number of unique degree
int64_t n_deg = bkt.size();
// initialize output
IdArray degs = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
IdArray new_uids = IdArray::Empty({n_edge}, uids->dtype, uids->ctx);
IdArray new_vids = IdArray::Empty({n_edge}, vids->dtype, vids->ctx);
IdArray new_eids = IdArray::Empty({n_edge}, eids->dtype, eids->ctx);
IdArray sections = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);
int64_t* deg_ptr = static_cast<int64_t*>(degs->data);
int64_t* uid_ptr = static_cast<int64_t*>(new_uids->data);
int64_t* vid_ptr = static_cast<int64_t*>(new_vids->data);
int64_t* eid_ptr = static_cast<int64_t*>(new_eids->data);
int64_t* sec_ptr = static_cast<int64_t*>(sections->data);
std::vector<IdArray> ret;
ret.push_back(std::move(degs));
ret.push_back(std::move(new_uids));
ret.push_back(std::move(new_vids));
ret.push_back(std::move(new_eids));
ret.push_back(std::move(sections));
// fill in bucketing ordering
for (const auto& it : bkt) { // for each bucket
// degree of this bucket
const int64_t deg = it.first;
// number of edges in this bucket
const int64_t bucket_size = it.second.size();
*deg_ptr++ = deg;
*sec_ptr++ = deg * bucket_size;
for (const auto u : it.second) { // for uid in this bucket
for (const auto& pair : node2edge[u]) { // for each edge of uid
*uid_ptr++ = u;
*vid_ptr++ = pair.second;
*eid_ptr++ = pair.first;
}
}
}
return std::move(ret);
}
std::vector<IdArray> ret;
ret.push_back(std::move(degs));
ret.push_back(std::move(new_uids));
ret.push_back(std::move(new_vids));
ret.push_back(std::move(new_eids));
ret.push_back(std::move(sections));
template std::vector<IdArray> GroupEdgeByNodeDegree<int32_t>(
const IdArray& uids, const IdArray& vids, const IdArray& eids);
return std::move(ret);
}
template std::vector<IdArray> GroupEdgeByNodeDegree<int64_t>(
const IdArray& uids, const IdArray& vids, const IdArray& eids);
} // namespace sched
......
......@@ -3,9 +3,11 @@
* \file scheduler/scheduler_apis.cc
* \brief DGL scheduler APIs
*/
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/scheduler.h>
#include "../c_api_common.h"
#include "../array/cpu/array_utils.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue;
......@@ -14,11 +16,16 @@ using dgl::runtime::NDArray;
namespace dgl {
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray msg_ids = args[0];
const IdArray vids = args[1];
const IdArray nids = args[2];
*rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(msg_ids, vids, nids));
CHECK_SAME_DTYPE(msg_ids, vids);
CHECK_SAME_DTYPE(msg_ids, nids);
ATEN_ID_TYPE_SWITCH(msg_ids->dtype, IdType, {
*rv = ConvertNDArrayVectorToPackedFunc(
sched::DegreeBucketing<IdType>(msg_ids, vids, nids));
});
});
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
......@@ -26,8 +33,12 @@ DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
const IdArray uids = args[0];
const IdArray vids = args[1];
const IdArray eids = args[2];
*rv = ConvertNDArrayVectorToPackedFunc(
sched::GroupEdgeByNodeDegree(uids, vids, eids));
CHECK_SAME_DTYPE(uids, vids);
CHECK_SAME_DTYPE(uids, eids);
ATEN_ID_TYPE_SWITCH(uids->dtype, IdType, {
*rv = ConvertNDArrayVectorToPackedFunc(
sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids));
});
});
} // namespace dgl
......@@ -2,6 +2,7 @@ import dgl
import backend as F
from dgl.base import ALL
from utils import parametrize_dtype
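
# `parametrize_dtype` is a shared test decorator whose definition lives in the
# test utilities, not in this diff. A minimal pytest-based sketch consistent
# with how it is used below (hypothetical, for illustration):
#
#     import pytest
#     parametrize_dtype = pytest.mark.parametrize("index_dtype",
#                                                 ["int32", "int64"])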
def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes
......@@ -32,18 +33,19 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
def test_batching_hetero_topology():
@parametrize_dtype
def test_batching_hetero_topology(index_dtype):
"""Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
})
}, index_dtype=index_dtype)
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
})
}, index_dtype=index_dtype)
bg = dgl.batch_hetero([g1, g2])
assert bg.ntypes == g2.ntypes
......@@ -90,21 +92,23 @@ def test_batching_hetero_topology():
check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4)
def test_batching_hetero_and_batched_hetero_topology():
@parametrize_dtype
def test_batching_hetero_and_batched_hetero_topology(index_dtype):
"""Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
}, index_dtype=index_dtype)
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
}, index_dtype=index_dtype)
bg1 = dgl.batch_hetero([g1, g2])
g3 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1)],
('user', 'plays', 'game'): [(1, 0)]
})
}, index_dtype=index_dtype)
bg2 = dgl.batch_hetero([bg1, g3])
assert bg2.ntypes == g3.ntypes
assert bg2.etypes == g3.etypes
......@@ -149,12 +153,13 @@ def test_batching_hetero_and_batched_hetero_topology():
check_equivalence_between_heterographs(g2, g5)
check_equivalence_between_heterographs(g3, g6)
def test_batched_features():
@parametrize_dtype
def test_batched_features(index_dtype):
"""Test the features of batched DGLHeteroGraphs"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
}, index_dtype=index_dtype)
g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g1.nodes['game'].data['h1'] = F.tensor([[0.]])
......@@ -166,7 +171,7 @@ def test_batched_features():
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
}, index_dtype=index_dtype)
g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g2.nodes['game'].data['h1'] = F.tensor([[0.]])
......
......@@ -6,6 +6,9 @@ import dgl
import networkx as nx
from collections import defaultdict as ddict
import unittest
import pytest
import inspect
from utils import parametrize_dtype
D = 5
reduce_msg_shapes = set()
......@@ -25,7 +28,7 @@ def reduce_func(nodes):
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False):
def generate_graph(index_dtype='int64', grad=False):
'''
s, d, eid
0, 1, 0
......@@ -47,7 +50,7 @@ def generate_graph(grad=False):
9, 0, 16
'''
g = dgl.graph([(0,1), (1,9), (0,2), (2,9), (0,3), (3,9), (0,4), (4,9),
(0,5), (5,9), (0,6), (6,9), (0,7), (7,9), (0,8), (8,9), (9,0)])
(0,5), (5,9), (0,6), (6,9), (0,7), (7,9), (0,8), (8,9), (9,0)], index_dtype=index_dtype)
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
......@@ -60,27 +63,35 @@ def generate_graph(grad=False):
g.set_e_initializer(dgl.init.zero_initializer)
return g
def test_isolated_nodes():
g = dgl.graph([(0, 1), (1, 2)], num_nodes=5)
@parametrize_dtype
def test_isolated_nodes(index_dtype):
g = dgl.graph([(0, 1), (1, 2)], num_nodes=5, index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes() == 5
# Test backward compatibility
g = dgl.graph([(0, 1), (1, 2)], card=5)
g = dgl.graph([(0, 1), (1, 2)], card=5, index_dtype=index_dtype)
assert g.number_of_nodes() == 5
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', num_nodes=(5, 7))
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays',
'game', num_nodes=(5, 7), index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7
# Test backward compatibility
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays', 'game', card=(5, 7))
g = dgl.bipartite([(0, 2), (0, 3), (1, 2)], 'user', 'plays',
'game', card=(5, 7), index_dtype=index_dtype)
assert g._idtype_str == index_dtype
assert g.number_of_nodes('user') == 5
assert g.number_of_nodes('game') == 7
def test_batch_setter_getter():
@parametrize_dtype
def test_batch_setter_getter(index_dtype):
def _pfc(x):
return list(F.zerocopy_to_numpy(x)[:,0])
g = generate_graph()
g = generate_graph(index_dtype)
# set all nodes
g.ndata['h'] = F.zeros((10, D))
assert F.allclose(g.ndata['h'], F.zeros((10, D)))
......@@ -90,11 +101,11 @@ def test_batch_setter_getter():
assert len(g.ndata) == old_len - 1
g.ndata['h'] = F.zeros((10, D))
# set partial nodes
u = F.tensor([1, 3, 5])
u = F.tensor([1, 3, 5], F.data_type_dict[index_dtype])
g.nodes[u].data['h'] = F.ones((3, D))
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = F.tensor([1, 2, 3])
u = F.tensor([1, 2, 3], F.data_type_dict[index_dtype])
assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]
'''
......@@ -126,42 +137,44 @@ def test_batch_setter_getter():
assert len(g.edata) == old_len - 1
g.edata['l'] = F.zeros((17, D))
# set partial edges (many-many)
u = F.tensor([0, 0, 2, 5, 9])
v = F.tensor([1, 3, 9, 9, 0])
u = F.tensor([0, 0, 2, 5, 9], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 3, 9, 9, 0], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((5, D))
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.edata['l']) == truth
# set partial edges (many-one)
u = F.tensor([3, 4, 6])
v = F.tensor([9])
u = F.tensor([3, 4, 6], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((3, D))
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.edata['l']) == truth
# set partial edges (one-many)
u = F.tensor([0])
v = F.tensor([4, 5, 6])
u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([4, 5, 6], dtype=F.data_type_dict[index_dtype])
g.edges[u, v].data['l'] = F.ones((3, D))
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.edata['l']) == truth
# get partial edges (many-many)
u = F.tensor([0, 6, 0])
v = F.tensor([6, 9, 7])
u = F.tensor([0, 6, 0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([6, 9, 7], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = F.tensor([5, 6, 7])
v = F.tensor([9])
u = F.tensor([5, 6, 7], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = F.tensor([0])
v = F.tensor([3, 4, 5])
u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([3, 4, 5], dtype=F.data_type_dict[index_dtype])
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
@parametrize_dtype
def test_batch_setter_autograd(index_dtype):
g = generate_graph(index_dtype=index_dtype, grad=True)
h1 = g.ndata['h']
# partial set
v = F.tensor([1, 2, 8])
v = F.tensor([1, 2, 8], F.data_type_dict[index_dtype])
hh = F.attach_grad(F.zeros((len(v), D)))
with F.record_grad():
g.nodes[v].data['h'] = hh
......@@ -170,7 +183,9 @@ def test_batch_setter_autograd():
assert F.array_equal(F.grad(h1)[:,0], F.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
assert F.array_equal(F.grad(hh)[:,0], F.tensor([2., 2., 2.]))
def test_nx_conversion():
@parametrize_dtype
def atest_nx_conversion(index_dtype):
# check conversion between networkx and DGLGraph
def _check_nx_feature(nxg, nf, ef):
......@@ -207,7 +222,7 @@ def test_nx_conversion():
n3 = F.randn((5, 4))
e1 = F.randn((4, 5))
e2 = F.randn((4, 7))
g = dgl.graph([(0,2),(1,4),(3,0),(4,3)])
g = dgl.graph([(0,2),(1,4),(3,0),(4,3)], index_dtype=index_dtype)
g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
g.edata.update({'e1': e1, 'e2': e2})
......@@ -219,7 +234,8 @@ def test_nx_conversion():
# convert to DGLGraph, nx graph has id in edge feature
# use id feature to test non-tensor copy
g = dgl.graph(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id'])
g = dgl.graph(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id'], index_dtype=index_dtype)
assert g._idtype_str == index_dtype
# check graph size
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
......@@ -289,61 +305,67 @@ def test_nx_conversion():
assert F.allclose(g.edata['h'], F.tensor([[1., 2.], [1., 2.],
[2., 3.], [2., 3.]]))
def test_batch_send():
g = generate_graph()
@parametrize_dtype
def test_batch_send(index_dtype):
g = generate_graph(index_dtype=index_dtype)
def _fmsg(edges):
assert tuple(F.shape(edges.src['h'])) == (5, D)
return {'m' : edges.src['h']}
# many-many send
u = F.tensor([0, 0, 0, 0, 0])
v = F.tensor([1, 2, 3, 4, 5])
u = F.tensor([0, 0, 0, 0, 0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg)
# one-many send
u = F.tensor([0])
v = F.tensor([1, 2, 3, 4, 5])
u = F.tensor([0], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg)
# many-one send
u = F.tensor([1, 2, 3, 4, 5])
v = F.tensor([9])
u = F.tensor([1, 2, 3, 4, 5], dtype=F.data_type_dict[index_dtype])
v = F.tensor([9], dtype=F.data_type_dict[index_dtype])
g.send((u, v), _fmsg)
def test_batch_recv():
@parametrize_dtype
def test_batch_recv(index_dtype):
# basic recv test
g = generate_graph()
u = F.tensor([0, 0, 0, 4, 5, 6])
v = F.tensor([1, 2, 3, 9, 9, 9])
g = generate_graph(index_dtype=index_dtype)
u = F.tensor([0, 0, 0, 4, 5, 6], dtype=F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 9, 9, 9], dtype=F.data_type_dict[index_dtype])
reduce_msg_shapes.clear()
g.send((u, v), message_func)
g.recv(F.unique(v), reduce_func, apply_node_func)
g.recv(F.astype(F.unique(v), F.data_type_dict[index_dtype]), reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_apply_nodes():
@parametrize_dtype
def test_apply_nodes(index_dtype):
def _upd(nodes):
return {'h' : nodes.data['h'] * 2}
g = generate_graph()
g = generate_graph(index_dtype=index_dtype)
old = g.ndata['h']
g.apply_nodes(_upd)
assert F.allclose(old * 2, g.ndata['h'])
u = F.tensor([0, 3, 4, 6])
u = F.tensor([0, 3, 4, 6], F.data_type_dict[index_dtype])
g.apply_nodes(lambda nodes : {'h' : nodes.data['h'] * 0.}, u)
assert F.allclose(F.gather_row(g.ndata['h'], u), F.zeros((4, D)))
def test_apply_edges():
@parametrize_dtype
def test_apply_edges(index_dtype):
def _upd(edges):
return {'w' : edges.data['w'] * 2}
g = generate_graph()
g = generate_graph(index_dtype=index_dtype)
old = g.edata['w']
g.apply_edges(_upd)
assert F.allclose(old * 2, g.edata['w'])
u = F.tensor([0, 0, 0, 4, 5, 6])
v = F.tensor([1, 2, 3, 9, 9, 9])
u = F.tensor([0, 0, 0, 4, 5, 6], F.data_type_dict[index_dtype])
v = F.tensor([1, 2, 3, 9, 9, 9], F.data_type_dict[index_dtype])
g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
eid = F.tensor(g.edge_ids(u, v))
eid = F.tensor(g.edge_ids(u, v), F.data_type_dict[index_dtype])
assert F.allclose(F.gather_row(g.edata['w'], eid), F.zeros((6, D)))
def test_update_routines():
g = generate_graph()
@parametrize_dtype
def test_update_routines(index_dtype):
g = generate_graph(index_dtype=index_dtype)
# send_and_recv
reduce_msg_shapes.clear()
......@@ -359,14 +381,14 @@ def test_update_routines():
pass
# pull
v = F.tensor([1, 2, 3, 9])
v = F.tensor([1, 2, 3, 9], F.data_type_dict[index_dtype])
reduce_msg_shapes.clear()
g.pull(v, message_func, reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# push
v = F.tensor([0, 1, 2, 3])
v = F.tensor([0, 1, 2, 3], F.data_type_dict[index_dtype])
reduce_msg_shapes.clear()
g.push(v, message_func, reduce_func, apply_node_func)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
......@@ -378,9 +400,10 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def test_recv_0deg():
@parametrize_dtype
def test_recv_0deg(index_dtype):
# test recv with 0deg nodes;
g = dgl.graph([(0,1)])
g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -412,9 +435,11 @@ def test_recv_0deg():
# non-0deg check: untouched
assert F.allclose(new[1], old[1])
def test_recv_0deg_newfld():
@parametrize_dtype
def test_recv_0deg_newfld(index_dtype):
# test recv with 0deg nodes; the reducer also creates a new field
g = dgl.graph([(0,1)])
g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -447,9 +472,10 @@ def test_recv_0deg_newfld():
# non-0deg check: not changed
assert F.allclose(new[1], F.full_1d(5, -1, F.int64))
def test_update_all_0deg():
@parametrize_dtype
def test_update_all_0deg(index_dtype):
# test#1
g = dgl.graph([(1,0), (2,0), (3,0), (4,0)])
g = dgl.graph([(1,0), (2,0), (3,0), (4,0)], index_dtype=index_dtype)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -470,7 +496,7 @@ def test_update_all_0deg():
assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0))
# test#2:
g = dgl.graph([], num_nodes=5)
g = dgl.graph([], num_nodes=5, index_dtype=index_dtype)
g.set_n_initializer(_init2, 'h')
g.ndata['h'] = old_repr
g.update_all(_message, _reduce, _apply)
......@@ -478,8 +504,9 @@ def test_update_all_0deg():
# should fallback to apply
assert F.allclose(new_repr, 2*old_repr)
def test_pull_0deg():
g = dgl.graph([(0,1)])
@parametrize_dtype
def test_pull_0deg(index_dtype):
g = dgl.graph([(0,1)], index_dtype=index_dtype)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -509,8 +536,9 @@ def test_pull_0deg():
# non-0deg check: not touched
assert F.allclose(new[1], old[1])
def test_send_multigraph():
g = dgl.graph([(0,1), (0,1), (0,1), (2,1)])
@parametrize_dtype
def test_send_multigraph(index_dtype):
g = dgl.graph([(0,1), (0,1), (0,1), (2,1)], index_dtype=index_dtype)
def _message_a(edges):
return {'a': edges.data['a']}
......@@ -607,9 +635,9 @@ def _test_dynamic_addition():
g.add_edge(2, 1, {'h1': F.randn((1, D))})
assert len(g.edata['h1']) == len(g.edata['h2'])
def test_repr():
G = dgl.graph([(0,1), (0,2), (1,2)], num_nodes=10)
@parametrize_dtype
def test_repr(index_dtype):
G = dgl.graph([(0,1), (0,2), (1,2)], num_nodes=10, index_dtype=index_dtype)
repr_string = G.__repr__()
print(repr_string)
G.ndata['x'] = F.zeros((10, 5))
......@@ -618,7 +646,8 @@ def test_repr():
print(repr_string)
def test_group_apply_edges():
@parametrize_dtype
def test_group_apply_edges(index_dtype):
def edge_udf(edges):
h = F.sum(edges.data['feat'] * (edges.src['h'] + edges.dst['h']), dim=2)
normalized_feat = F.softmax(h, dim=1)
......@@ -631,7 +660,7 @@ def test_group_apply_edges():
elist.append((1, v))
for v in [2, 3, 4, 5, 6, 7, 8]:
elist.append((2, v))
g = dgl.graph(elist)
g = dgl.graph(elist, index_dtype=index_dtype)
g.ndata['h'] = F.randn((g.number_of_nodes(), D))
g.edata['feat'] = F.randn((g.number_of_edges(), D))
......@@ -653,8 +682,9 @@ def test_group_apply_edges():
# test group by destination nodes
_test('dst')
def test_local_var():
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)])
@parametrize_dtype
def test_local_var(index_dtype):
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)], index_dtype=index_dtype)
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
g.edata['w'] = F.zeros((g.number_of_edges(), 4))
# test override
......@@ -710,8 +740,9 @@ def test_local_var():
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
def test_local_scope():
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)])
@parametrize_dtype
def test_local_scope(index_dtype):
g = dgl.graph([(0,1), (1,2), (2,3), (3,4)], index_dtype=index_dtype)
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
g.edata['w'] = F.zeros((g.number_of_edges(), 4))
# test override
......@@ -762,7 +793,7 @@ def test_local_scope():
assert 'ww' not in g.edata
# test initializer1
g = dgl.graph([(0,1), (1,1)])
g = dgl.graph([(0,1), (1,1)], index_dtype=index_dtype)
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
with g.local_scope():
......@@ -781,32 +812,35 @@ def test_local_scope():
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
def test_issue_1088():
@parametrize_dtype
def test_issue_1088(index_dtype):
# This test ensures that message passing on a heterograph with one edge type
# would not crash (GitHub issue #1088).
import dgl.function as fn
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])})
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])}, index_dtype=index_dtype)
g.nodes['U'].data['x'] = F.randn((3, 3))
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
if __name__ == '__main__':
test_isolated_nodes()
test_nx_conversion()
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_apply_nodes()
test_apply_edges()
test_update_routines()
test_recv_0deg()
test_recv_0deg_newfld()
test_update_all_0deg()
test_pull_0deg()
test_send_multigraph()
test_dynamic_addition()
test_repr()
test_group_apply_edges()
test_local_var()
test_local_scope()
test_issue_1088()
# test_isolated_nodes("int32")
# test_nx_conversion()
# test_batch_setter_getter("int32")
# test_batch_recv("int64")
test_apply_edges("int32")
# test_batch_setter_autograd()
# test_batch_send()
# test_batch_recv()
# test_apply_nodes()
# test_apply_edges()
# test_update_routines()
# test_recv_0deg()
# test_recv_0deg_newfld()
# test_update_all_0deg()
# test_pull_0deg()
# test_send_multigraph()
# test_dynamic_addition()
# test_repr()
# test_group_apply_edges()
# test_local_var()
# test_local_scope()
# test_issue_1088()