Unverified Commit dc8ca88e authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Refactor] Explicit dtype for HeteroGraph (#1467)



* 111

* 111

* lint

* lint

* lint

* lint

* fix

* lint

* try

* fix

* lint

* lint

* test

* fix

* ttt

* test

* fix

* fix

* fix

* mxnet

* 111

* fix 64bits computation

* pylint

* roll back

* fix

* lint

* fix hetero_from_relations

* remove index_dtype in to_homo and to_hetero

* fix

* fix

* fix

* fix

* remove default

* fix

* lint

* fix

* fix error message

* fix error

* lint

* macro dispatch

* try

* lint

* remove nbits

* error message

* fix

* fix

* lint

* lint

* lint

* fix

* lint

* fix

* fix random walk

* lint

* lint

* fix

* fix

* fix

* lint

* fix

* lint
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent de34e15a
...@@ -100,7 +100,7 @@ pipeline { ...@@ -100,7 +100,7 @@ pipeline {
stage("Lint Check") { stage("Lint Check") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-lint" image "dgllib/dgl-ci-lint"
} }
} }
...@@ -119,7 +119,7 @@ pipeline { ...@@ -119,7 +119,7 @@ pipeline {
stage("CPU Build") { stage("CPU Build") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-cpu:conda" image "dgllib/dgl-ci-cpu:conda"
} }
} }
...@@ -135,7 +135,7 @@ pipeline { ...@@ -135,7 +135,7 @@ pipeline {
stage("GPU Build") { stage("GPU Build") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-gpu:conda" image "dgllib/dgl-ci-gpu:conda"
args "-u root" args "-u root"
} }
...@@ -171,7 +171,7 @@ pipeline { ...@@ -171,7 +171,7 @@ pipeline {
stage("C++ CPU") { stage("C++ CPU") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-cpu:conda" image "dgllib/dgl-ci-cpu:conda"
} }
} }
...@@ -198,7 +198,7 @@ pipeline { ...@@ -198,7 +198,7 @@ pipeline {
stage("Tensorflow CPU") { stage("Tensorflow CPU") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-cpu:conda" image "dgllib/dgl-ci-cpu:conda"
} }
} }
...@@ -239,7 +239,7 @@ pipeline { ...@@ -239,7 +239,7 @@ pipeline {
stage("Torch CPU") { stage("Torch CPU") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-cpu:conda" image "dgllib/dgl-ci-cpu:conda"
} }
} }
...@@ -316,7 +316,7 @@ pipeline { ...@@ -316,7 +316,7 @@ pipeline {
stage("MXNet CPU") { stage("MXNet CPU") {
agent { agent {
docker { docker {
label "linux-cpu-node" label "linux-c52x-node"
image "dgllib/dgl-ci-cpu:conda" image "dgllib/dgl-ci-cpu:conda"
} }
} }
......
...@@ -315,7 +315,10 @@ struct CSRMatrix { ...@@ -315,7 +315,10 @@ struct CSRMatrix {
indptr(parr), indptr(parr),
indices(iarr), indices(iarr),
data(darr), data(darr),
sorted(sorted_flag) {} sorted(sorted_flag) {
CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
<< "The indptr and indices arrays must have the same data type.";
}
/*! \brief constructor from SparseMatrix object */ /*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat) explicit CSRMatrix(const SparseMatrix& spmat)
...@@ -324,7 +327,10 @@ struct CSRMatrix { ...@@ -324,7 +327,10 @@ struct CSRMatrix {
indptr(spmat.indices[0]), indptr(spmat.indices[0]),
indices(spmat.indices[1]), indices(spmat.indices[1]),
data(spmat.indices[2]), data(spmat.indices[2]),
sorted(spmat.flags[0]) {} sorted(spmat.flags[0]) {
CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
<< "The indptr and indices arrays must have the same data type.";
}
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
...@@ -394,7 +400,10 @@ struct COOMatrix { ...@@ -394,7 +400,10 @@ struct COOMatrix {
col(carr), col(carr),
data(darr), data(darr),
row_sorted(rsorted), row_sorted(rsorted),
col_sorted(csorted) {} col_sorted(csorted) {
CHECK_EQ(row->dtype.bits, col->dtype.bits)
<< "The row and col arrays must have the same data type.";
}
/*! \brief constructor from SparseMatrix object */ /*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat) explicit COOMatrix(const SparseMatrix& spmat)
...@@ -404,7 +413,10 @@ struct COOMatrix { ...@@ -404,7 +413,10 @@ struct COOMatrix {
col(spmat.indices[1]), col(spmat.indices[1]),
data(spmat.indices[2]), data(spmat.indices[2]),
row_sorted(spmat.flags[0]), row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {} col_sorted(spmat.flags[1]) {
CHECK_EQ(row->dtype.bits, col->dtype.bits)
<< "The row and col arrays must have the same data type.";
}
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
...@@ -880,6 +892,29 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -880,6 +892,29 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} \ } \
} while (0) } while (0)
/*
 * Dispatch according to bits (either int32 or int64):
 *
 * ATEN_ID_BITS_SWITCH(bits, IdType, {
 *   // Now IdType is the integer type matching the given bit width:
 *   // int32_t when bits == 32, int64_t when bits == 64.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
 * });
 *
 * The `bits` expression is evaluated exactly once, so it is safe to pass
 * an expression that is expensive or has side effects. Any width other
 * than 32 or 64 is reported through CHECK.
 */
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...)                          \
  do {                                                                  \
    const auto aten_id_bits_ = (bits);                                  \
    CHECK(aten_id_bits_ == 32 || aten_id_bits_ == 64)                   \
        << "bits must be 32 or 64";                                     \
    if (aten_id_bits_ == 32) {                                          \
      typedef int32_t IdType;                                           \
      { __VA_ARGS__ }                                                   \
    } else if (aten_id_bits_ == 64) {                                   \
      typedef int64_t IdType;                                           \
      { __VA_ARGS__ }                                                   \
    } else {                                                            \
      LOG(FATAL) << "ID can only be int32 or int64";                    \
    }                                                                   \
  } while (0)
/* /*
* Dispatch according to float type (either float32 or float64): * Dispatch according to float type (either float32 or float64):
* *
......
...@@ -669,10 +669,12 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra ...@@ -669,10 +669,12 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
* *
* TODO(minjie): remove the meta_graph argument * TODO(minjie): remove the meta_graph argument
* *
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param meta_graph Metagraph of the inputs and result. * \param meta_graph Metagraph of the inputs and result.
* \param component_graphs Input graphs * \param component_graphs Input graphs
* \return One graph that unions all the components * \return One graph that unions all the components
*/ */
template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph( HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
...@@ -689,12 +691,14 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -689,12 +691,14 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
* TODO(minjie): remove the meta_graph argument; use vector<IdArray> for vertex_sizes * TODO(minjie): remove the meta_graph argument; use vector<IdArray> for vertex_sizes
* and edge_sizes. * and edge_sizes.
* *
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param meta_graph Metagraph. * \param meta_graph Metagraph.
* \param batched_graph Input graph. * \param batched_graph Input graph.
* \param vertex_sizes Number of vertices of each component. * \param vertex_sizes Number of vertices of each component.
* \param edge_sizes Number of edges of each component. * \param edge_sizes Number of edges of each component.
* \return A list of graphs representing each disjoint components. * \return A list of graphs representing each disjoint components.
*/ */
template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, GraphPtr meta_graph,
HeteroGraphPtr batched_graph, HeteroGraphPtr batched_graph,
...@@ -714,7 +718,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -714,7 +718,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
* This class can be used as arguments and return values of a C API. * This class can be used as arguments and return values of a C API.
*/ */
struct HeteroPickleStates : public runtime::Object { struct HeteroPickleStates : public runtime::Object {
/*! \brief Metagraph. */ /*! \brief Metagraph(64bits ImmutableGraph) */
GraphPtr metagraph; GraphPtr metagraph;
/*! \brief Number of nodes per type */ /*! \brief Number of nodes per type */
......
...@@ -49,6 +49,36 @@ class DGLIdIters { ...@@ -49,6 +49,36 @@ class DGLIdIters {
const dgl_id_t *begin_{nullptr}, *end_{nullptr}; const dgl_id_t *begin_{nullptr}, *end_{nullptr};
}; };
/*!
 * \brief Read-only view over a contiguous run of int32 ids.
 *
 * This is the 32-bit counterpart of DGLIdIters. The object only stores the
 * two boundary pointers it is given; it never allocates or frees memory.
 */
class DGLIdIters32 {
 public:
  /*! \brief Default constructor; both pointers are null (empty range). */
  DGLIdIters32() {}
  /*! \brief Build a view over the half-open range [begin, end). */
  DGLIdIters32(const int32_t *begin, const int32_t *end)
      : begin_(begin), end_(end) {}
  /*! \brief Pointer to the first id of the range. */
  const int32_t *begin() const { return begin_; }
  /*! \brief Pointer one past the last id of the range. */
  const int32_t *end() const { return end_; }
  /*! \brief The id stored at offset i from begin(); no bounds checking. */
  int32_t operator[](int32_t i) const { return begin_[i]; }
  /*! \brief Number of ids between begin() and end(). */
  size_t size() const { return end_ - begin_; }

 private:
  /*! \brief First element of the viewed range. */
  const int32_t *begin_{nullptr};
  /*! \brief One past the last element of the viewed range. */
  const int32_t *end_{nullptr};
};
/* \brief structure used to represent a list of edges */ /* \brief structure used to represent a list of edges */
typedef struct { typedef struct {
/* \brief the two endpoints and the id of the edge */ /* \brief the two endpoints and the id of the edge */
......
...@@ -17,6 +17,7 @@ namespace sched { ...@@ -17,6 +17,7 @@ namespace sched {
/*! /*!
* \brief Generate degree bucketing schedule * \brief Generate degree bucketing schedule
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param msg_ids The edge id for each message * \param msg_ids The edge id for each message
* \param vids The destination vertex for each message * \param vids The destination vertex for each message
* \param recv_ids The recv nodes (for checking zero degree nodes) * \param recv_ids The recv nodes (for checking zero degree nodes)
...@@ -29,11 +30,13 @@ namespace sched { ...@@ -29,11 +30,13 @@ namespace sched {
* mids: message ids * mids: message ids
* mid_section: number of messages in each bucket (used to split mids) * mid_section: number of messages in each bucket (used to split mids)
*/ */
template <class IdType>
std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids, std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
const IdArray& recv_ids); const IdArray& recv_ids);
/*! /*!
* \brief Generate degree bucketing schedule for group_apply edge * \brief Generate degree bucketing schedule for group_apply edge
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param uids One end vertex of edge by which edges are grouped * \param uids One end vertex of edge by which edges are grouped
* \param vids The other end vertex of edge * \param vids The other end vertex of edge
* \param eids Edge ids * \param eids Edge ids
...@@ -49,6 +52,7 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids ...@@ -49,6 +52,7 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
* sections: number of edges in each degree bucket (used to partition * sections: number of edges in each degree bucket (used to partition
* new_uids, new_vids, and new_eids) * new_uids, new_vids, and new_eids)
*/ */
template <class IdType>
std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids, std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids,
const IdArray& vids, const IdArray& eids); const IdArray& vids, const IdArray& eids);
......
...@@ -1073,7 +1073,7 @@ def sort_1d(input): ...@@ -1073,7 +1073,7 @@ def sort_1d(input):
""" """
pass pass
def arange(start, stop): def arange(start, stop, dtype):
"""Create a 1D range int64 tensor. """Create a 1D range int64 tensor.
Parameters Parameters
...@@ -1082,6 +1082,8 @@ def arange(start, stop): ...@@ -1082,6 +1082,8 @@ def arange(start, stop):
The range start. The range start.
stop : int stop : int
The range stop. The range stop.
dtype: str
The dtype of result tensor
Returns Returns
------- -------
......
...@@ -338,11 +338,11 @@ def sort_1d(input): ...@@ -338,11 +338,11 @@ def sort_1d(input):
idx = nd.cast(idx, dtype='int64') idx = nd.cast(idx, dtype='int64')
return val, idx return val, idx
def arange(start, stop): def arange(start, stop, dtype="int64"):
if start >= stop: if start >= stop:
return nd.array([], dtype=np.int64) return nd.array([], dtype=data_type_dict()[dtype])
else: else:
return nd.arange(start, stop, dtype=np.int64) return nd.arange(start, stop, dtype=data_type_dict()[dtype])
def rand_shuffle(arr): def rand_shuffle(arr):
return mx.nd.random.shuffle(arr) return mx.nd.random.shuffle(arr)
......
...@@ -194,8 +194,8 @@ def nonzero_1d(input): ...@@ -194,8 +194,8 @@ def nonzero_1d(input):
def sort_1d(input): def sort_1d(input):
return np.sort(input), np.argsort(input) return np.sort(input), np.argsort(input)
def arange(start, stop): def arange(start, stop, dtype="int64"):
return np.arange(start, stop, dtype=np.int64) return np.arange(start, stop, dtype=getattr(np, dtype))
def rand_shuffle(arr): def rand_shuffle(arr):
copy = np.copy(arr) copy = np.copy(arr)
......
...@@ -154,7 +154,7 @@ def repeat(input, repeats, dim): ...@@ -154,7 +154,7 @@ def repeat(input, repeats, dim):
return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1) return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
def gather_row(data, row_index): def gather_row(data, row_index):
return th.index_select(data, 0, row_index) return th.index_select(data, 0, row_index.long())
def slice_axis(data, axis, begin, end): def slice_axis(data, axis, begin, end):
return th.narrow(data, axis, begin, end - begin) return th.narrow(data, axis, begin, end - begin)
...@@ -167,7 +167,7 @@ def narrow_row(x, start, stop): ...@@ -167,7 +167,7 @@ def narrow_row(x, start, stop):
return x[start:stop] return x[start:stop]
def scatter_row(data, row_index, value): def scatter_row(data, row_index, value):
return data.index_copy(0, row_index, value) return data.index_copy(0, row_index.long(), value)
def scatter_row_inplace(data, row_index, value): def scatter_row_inplace(data, row_index, value):
data[row_index] = value data[row_index] = value
...@@ -263,8 +263,8 @@ def nonzero_1d(input): ...@@ -263,8 +263,8 @@ def nonzero_1d(input):
def sort_1d(input): def sort_1d(input):
return th.sort(input) return th.sort(input)
def arange(start, stop): def arange(start, stop, dtype="int64"):
return th.arange(start, stop, dtype=th.int64) return th.arange(start, stop, dtype=data_type_dict()[dtype])
def rand_shuffle(arr): def rand_shuffle(arr):
idx = th.randperm(len(arr)) idx = th.randperm(len(arr))
......
...@@ -78,8 +78,10 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -78,8 +78,10 @@ def sparse_matrix(data, index, shape, force_format=False):
if fmt != 'coo': if fmt != 'coo':
raise TypeError( raise TypeError(
'Tensorflow backend only supports COO format. But got %s.' % fmt) 'Tensorflow backend only supports COO format. But got %s.' % fmt)
spmat = tf.SparseTensor(indices=tf.transpose( # tf.SparseTensor only supports int64 indexing,
index[1], (1, 0)), values=data, dense_shape=shape) # therefore manually casting to int64 when input in int32
spmat = tf.SparseTensor(indices=tf.cast(tf.transpose(
index[1], (1, 0)), tf.int64), values=data, dense_shape=shape)
return spmat, None return spmat, None
...@@ -372,9 +374,9 @@ def sort_1d(input): ...@@ -372,9 +374,9 @@ def sort_1d(input):
return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64) return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)
def arange(start, stop): def arange(start, stop, dtype="int64"):
with tf.device("/cpu:0"): with tf.device("/cpu:0"):
t = tf.range(start, stop, dtype=tf.int64) t = tf.range(start, stop, dtype=data_type_dict()[dtype])
return t return t
......
...@@ -209,8 +209,8 @@ def metapath_random_walk(hg, etypes, seeds, num_traces): ...@@ -209,8 +209,8 @@ def metapath_random_walk(hg, etypes, seeds, num_traces):
raise ValueError('beginning and ending node type mismatch') raise ValueError('beginning and ending node type mismatch')
if len(seeds) == 0: if len(seeds) == 0:
return [] return []
etype_array = ndarray.array(np.asarray([hg.get_etype_id(et) for et in etypes], dtype='int64')) etype_array = ndarray.array(np.asarray([hg.get_etype_id(et) for et in etypes], dtype="int64"))
seed_array = utils.toindex(seeds).todgltensor() seed_array = utils.toindex(seeds, hg._idtype_str).todgltensor()
traces = _CAPI_DGLMetapathRandomWalk(hg._graph, etype_array, seed_array, num_traces) traces = _CAPI_DGLMetapathRandomWalk(hg._graph, etype_array, seed_array, num_traces)
return _split_traces(traces) return _split_traces(traces)
......
...@@ -22,7 +22,7 @@ __all__ = [ ...@@ -22,7 +22,7 @@ __all__ = [
] ]
def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True, def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True,
restrict_format='any', **kwargs): restrict_format='any', index_dtype="int64", **kwargs):
"""Create a graph with one type of nodes and edges. """Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...@@ -135,22 +135,22 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True ...@@ -135,22 +135,22 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
u, v = data u, v = data
return create_from_edges( return create_from_edges(
u, v, ntype, etype, ntype, urange, vrange, validate, u, v, ntype, etype, ntype, urange, vrange, validate,
restrict_format=restrict_format) restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, list): elif isinstance(data, list):
return create_from_edge_list( return create_from_edge_list(
data, ntype, etype, ntype, urange, vrange, validate, data, ntype, etype, ntype, urange, vrange, validate,
restrict_format=restrict_format) restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy( return create_from_scipy(
data, ntype, etype, ntype, restrict_format=restrict_format) data, ntype, etype, ntype, restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
return create_from_networkx( return create_from_networkx(
data, ntype, etype, restrict_format=restrict_format, **kwargs) data, ntype, etype, restrict_format=restrict_format, index_dtype=index_dtype, **kwargs)
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None, def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None,
validate=True, restrict_format='any', **kwargs): validate=True, restrict_format='any', index_dtype='int64', **kwargs):
"""Create a bipartite graph. """Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes The result graph is directed and edges must be from ``utype`` nodes
...@@ -282,18 +282,19 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non ...@@ -282,18 +282,19 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
if isinstance(data, tuple): if isinstance(data, tuple):
u, v = data u, v = data
return create_from_edges( return create_from_edges(
u, v, utype, etype, vtype, urange, vrange, validate, u, v, utype, etype, vtype, urange, vrange, validate, index_dtype=index_dtype,
restrict_format=restrict_format) restrict_format=restrict_format)
elif isinstance(data, list): elif isinstance(data, list):
return create_from_edge_list( return create_from_edge_list(
data, utype, etype, vtype, urange, vrange, validate, data, utype, etype, vtype, urange, vrange, validate, index_dtype=index_dtype,
restrict_format=restrict_format) restrict_format=restrict_format)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy( return create_from_scipy(
data, utype, etype, vtype, restrict_format=restrict_format) data, utype, etype, vtype, restrict_format=restrict_format, index_dtype=index_dtype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
return create_from_networkx_bipartite( return create_from_networkx_bipartite(data, utype, etype,
data, utype, etype, vtype, restrict_format=restrict_format, **kwargs) vtype, restrict_format=restrict_format,
index_dtype=index_dtype, **kwargs)
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
...@@ -385,13 +386,18 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None): ...@@ -385,13 +386,18 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
ntypes = list(sorted(ntype_set)) ntypes = list(sorted(ntype_set))
else: else:
ntypes = list(sorted(num_nodes_per_type.keys())) ntypes = list(sorted(num_nodes_per_type.keys()))
num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes]) num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes], "int64")
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)} ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
index_dtype = rel_graphs[0]._idtype_str
for rgrh in rel_graphs: for rgrh in rel_graphs:
if rgrh._idtype_str != index_dtype:
raise Exception("Expect relation graphs to be {}, but got {}".format(
index_dtype, rgrh._idtype_str))
stype, etype, dtype = rgrh.canonical_etypes[0] stype, etype, dtype = rgrh.canonical_etypes[0]
meta_edges_src.append(ntype_dict[stype]) meta_edges_src.append(ntype_dict[stype])
meta_edges_dst.append(ntype_dict[dtype]) meta_edges_dst.append(ntype_dict[dtype])
etypes.append(etype) etypes.append(etype)
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True) metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
# create graph index # create graph index
...@@ -404,7 +410,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None): ...@@ -404,7 +410,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
retg._edge_frames[i].update(rgrh._edge_frames[0]) retg._edge_frames[i].update(rgrh._edge_frames[0])
return retg return retg
def heterograph(data_dict, num_nodes_dict=None): def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
"""Create a heterogeneous graph from a dictionary between edge types and edge lists. """Create a heterogeneous graph from a dictionary between edge types and edge lists.
Parameters Parameters
...@@ -484,15 +490,18 @@ def heterograph(data_dict, num_nodes_dict=None): ...@@ -484,15 +490,18 @@ def heterograph(data_dict, num_nodes_dict=None):
elif srctype == dsttype: elif srctype == dsttype:
rel_graphs.append(graph( rel_graphs.append(graph(
data, srctype, etype, data, srctype, etype,
num_nodes=num_nodes_dict[srctype], validate=False)) num_nodes=num_nodes_dict[srctype], validate=False, index_dtype=index_dtype))
else: else:
rel_graphs.append(bipartite( rel_graphs.append(bipartite(
data, srctype, etype, dsttype, data, srctype, etype, dsttype,
num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False)) num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]),
validate=False, index_dtype=index_dtype))
return hetero_from_relations(rel_graphs, num_nodes_dict) return hetero_from_relations(rel_graphs, num_nodes_dict)
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE,
metagraph=None):
"""Convert the given homogeneous graph to a heterogeneous graph. """Convert the given homogeneous graph to a heterogeneous graph.
The input graph should have only one type of nodes and edges. Each node and edge The input graph should have only one type of nodes and edges. Each node and edge
...@@ -588,6 +597,7 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -588,6 +597,7 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
' type of nodes and edges.') ' type of nodes and edges.')
num_ntypes = len(ntypes) num_ntypes = len(ntypes)
index_dtype = G._idtype_str
ntype_ids = F.asnumpy(G.ndata[ntype_field]) ntype_ids = F.asnumpy(G.ndata[ntype_field])
etype_ids = F.asnumpy(G.edata[etype_field]) etype_ids = F.asnumpy(G.edata[etype_field])
...@@ -641,15 +651,18 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -641,15 +651,18 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
if stid == dtid: if stid == dtid:
rel_graph = graph( rel_graph = graph(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
num_nodes=ntype_count[stid], validate=False) num_nodes=ntype_count[stid], validate=False, index_dtype=index_dtype)
else: else:
rel_graph = bipartite( rel_graph = bipartite(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid], (src_of_etype,
num_nodes=(ntype_count[stid], ntype_count[dtid]), validate=False) dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
num_nodes=(ntype_count[stid], ntype_count[dtid]),
validate=False, index_dtype=index_dtype)
rel_graphs.append(rel_graph) rel_graphs.append(rel_graph)
hg = hetero_from_relations( hg = hetero_from_relations(rel_graphs,
rel_graphs, {ntype: count for ntype, count in zip(ntypes, ntype_count)}) {ntype: count for ntype, count in zip(
ntypes, ntype_count)})
ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)} ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
for ntid, ntype in enumerate(hg.ntypes): for ntid, ntype in enumerate(hg.ntypes):
...@@ -722,7 +735,7 @@ def to_homo(G): ...@@ -722,7 +735,7 @@ def to_homo(G):
num_nodes = G.number_of_nodes(ntype) num_nodes = G.number_of_nodes(ntype)
total_num_nodes += num_nodes total_num_nodes += num_nodes
ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu())) ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
nids.append(F.arange(0, num_nodes)) nids.append(F.arange(0, num_nodes, G._idtype_str))
for etype_id, etype in enumerate(G.canonical_etypes): for etype_id, etype in enumerate(G.canonical_etypes):
srctype, _, dsttype = etype srctype, _, dsttype = etype
...@@ -731,9 +744,10 @@ def to_homo(G): ...@@ -731,9 +744,10 @@ def to_homo(G):
srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)])) srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)])) dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu())) etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges)) eids.append(F.arange(0, num_edges, G._idtype_str))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes, validate=False) retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
validate=False, index_dtype=G._idtype_str)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0) retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0) retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0) retg.edata[ETYPE] = F.cat(etype_ids, 0)
...@@ -754,7 +768,7 @@ def to_homo(G): ...@@ -754,7 +768,7 @@ def to_homo(G):
############################################################ ############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True, def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True,
restrict_format="any"): restrict_format="any", index_dtype='int64'):
"""Internal function to create a graph from incident nodes with types. """Internal function to create a graph from incident nodes with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -786,8 +800,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -786,8 +800,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
------- -------
DGLHeteroGraph DGLHeteroGraph
""" """
u = utils.toindex(u) u = utils.toindex(u, index_dtype)
v = utils.toindex(v) v = utils.toindex(v, index_dtype)
if validate: if validate:
if urange is not None and len(u) > 0 and \ if urange is not None and len(u) > 0 and \
...@@ -817,7 +831,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -817,7 +831,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
return DGLHeteroGraph(hgidx, [utype, vtype], [etype]) return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
validate=True, restrict_format='any'): validate=True, restrict_format='any', index_dtype='int64'):
"""Internal function to create a heterograph from a list of edge tuples with types. """Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -853,11 +867,11 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, ...@@ -853,11 +867,11 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
u, v = zip(*elist) u, v = zip(*elist)
u = list(u) u = list(u)
v = list(v) v = list(v)
return create_from_edges( return create_from_edges(u, v, utype, etype, vtype, urange, vrange,
u, v, utype, etype, vtype, urange, vrange, validate, restrict_format) validate, restrict_format, index_dtype=index_dtype)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False, def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
restrict_format='any'): restrict_format='any', index_dtype='int64'):
"""Internal function to create a heterograph from a scipy sparse matrix with types. """Internal function to create a heterograph from a scipy sparse matrix with types.
Parameters Parameters
...@@ -889,16 +903,16 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False, ...@@ -889,16 +903,16 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
num_src, num_dst = spmat.shape num_src, num_dst = spmat.shape
num_ntypes = 1 if utype == vtype else 2 num_ntypes = 1 if utype == vtype else 2
if spmat.getformat() == 'coo': if spmat.getformat() == 'coo':
row = utils.toindex(spmat.row) row = utils.toindex(spmat.row.astype(index_dtype), index_dtype)
col = utils.toindex(spmat.col) col = utils.toindex(spmat.col.astype(index_dtype), index_dtype)
hgidx = heterograph_index.create_unitgraph_from_coo( hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, num_src, num_dst, row, col, restrict_format) num_ntypes, num_src, num_dst, row, col, restrict_format)
else: else:
spmat = spmat.tocsr() spmat = spmat.tocsr()
indptr = utils.toindex(spmat.indptr) indptr = utils.toindex(spmat.indptr.astype(index_dtype), index_dtype)
indices = utils.toindex(spmat.indices) indices = utils.toindex(spmat.indices.astype(index_dtype), index_dtype)
# TODO(minjie): with_edge_id is only reasonable for csr matrix. How to fix? # TODO(minjie): with_edge_id is only reasonable for csr matrix. How to fix?
data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices)))) data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))), index_dtype)
hgidx = heterograph_index.create_unitgraph_from_csr( hgidx = heterograph_index.create_unitgraph_from_csr(
num_ntypes, num_src, num_dst, indptr, indices, data, restrict_format) num_ntypes, num_src, num_dst, indptr, indices, data, restrict_format)
if num_ntypes == 1: if num_ntypes == 1:
...@@ -911,7 +925,7 @@ def create_from_networkx(nx_graph, ...@@ -911,7 +925,7 @@ def create_from_networkx(nx_graph,
edge_id_attr_name='id', edge_id_attr_name='id',
node_attrs=None, node_attrs=None,
edge_attrs=None, edge_attrs=None,
restrict_format='any'): restrict_format='any', index_dtype='int64'):
"""Create a heterograph that has only one set of nodes and edges. """Create a heterograph that has only one set of nodes and edges.
Parameters Parameters
...@@ -949,8 +963,8 @@ def create_from_networkx(nx_graph, ...@@ -949,8 +963,8 @@ def create_from_networkx(nx_graph,
if has_edge_id: if has_edge_id:
num_edges = nx_graph.number_of_edges() num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = attr[edge_id_attr_name] eid = attr[edge_id_attr_name]
src[eid] = u src[eid] = u
...@@ -961,11 +975,11 @@ def create_from_networkx(nx_graph, ...@@ -961,11 +975,11 @@ def create_from_networkx(nx_graph,
for e in nx_graph.edges: for e in nx_graph.edges:
src.append(e[0]) src.append(e[0])
dst.append(e[1]) dst.append(e[1])
src = utils.toindex(src) src = utils.toindex(src, index_dtype)
dst = utils.toindex(dst) dst = utils.toindex(dst, index_dtype)
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes, g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes,
validate=False, restrict_format=restrict_format) validate=False, restrict_format=restrict_format, index_dtype=index_dtype)
# handle features # handle features
# copy attributes # copy attributes
...@@ -1017,7 +1031,7 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1017,7 +1031,7 @@ def create_from_networkx_bipartite(nx_graph,
edge_id_attr_name='id', edge_id_attr_name='id',
node_attrs=None, node_attrs=None,
edge_attrs=None, edge_attrs=None,
restrict_format='any'): restrict_format='any', index_dtype='int64'):
"""Create a heterograph that has one set of source nodes, one set of """Create a heterograph that has one set of source nodes, one set of
destination nodes and one set of edges. destination nodes and one set of edges.
...@@ -1065,8 +1079,8 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1065,8 +1079,8 @@ def create_from_networkx_bipartite(nx_graph,
if has_edge_id: if has_edge_id:
num_edges = nx_graph.number_of_edges() num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
for u, v, attr in nx_graph.edges(data=True): for u, v, attr in nx_graph.edges(data=True):
eid = attr[edge_id_attr_name] eid = attr[edge_id_attr_name]
src[eid] = top_map[u] src[eid] = top_map[u]
...@@ -1078,11 +1092,11 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1078,11 +1092,11 @@ def create_from_networkx_bipartite(nx_graph,
if e[0] in top_map: if e[0] in top_map:
src.append(top_map[e[0]]) src.append(top_map[e[0]])
dst.append(bottom_map[e[1]]) dst.append(bottom_map[e[1]])
src = utils.toindex(src) src = utils.toindex(src, index_dtype)
dst = utils.toindex(dst) dst = utils.toindex(dst, index_dtype)
g = create_from_edges( g = create_from_edges(src, dst, utype, etype, vtype,
src, dst, utype, etype, vtype, len(top_nodes), len(bottom_nodes), validate=False,
len(top_nodes), len(bottom_nodes), validate=False, restrict_format=restrict_format) restrict_format=restrict_format, index_dtype=index_dtype)
# TODO attributes # TODO attributes
assert node_attrs is None, 'Retrieval of node attributes are not supported yet.' assert node_attrs is None, 'Retrieval of node attributes are not supported yet.'
......
...@@ -793,6 +793,29 @@ class DGLBaseGraph(object): ...@@ -793,6 +793,29 @@ class DGLBaseGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor() return self._graph.out_degrees(v).tousertensor()
@property
def idtype(self):
    """Backend dtype object matching the dtype of the graph index.

    The dtype name reported by ``self._graph.dtype`` is looked up as an
    attribute of the backend module ``F``.

    Returns
    -------
    backend dtype object
        th.int32/th.int64 or tf.int32/tf.int64 etc.
    """
    return getattr(F, self._graph.dtype)

@property
def _idtype_str(self):
    """Name of the dtype of the graph index.

    Unlike ``idtype``, the value of ``self._graph.dtype`` is returned as-is
    and is not converted into a backend dtype object.

    Returns
    -------
    str
        The dtype name; presumably "int32" or "int64" -- verify against
        the graph index implementation.
    """
    return self._graph.dtype
def mutation(func): def mutation(func):
"""A decorator to decorate functions that might change graph structure.""" """A decorator to decorate functions that might change graph structure."""
......
...@@ -864,6 +864,21 @@ class GraphIndex(ObjectBase): ...@@ -864,6 +864,21 @@ class GraphIndex(ObjectBase):
""" """
return _CAPI_DGLGraphContext(self) return _CAPI_DGLGraphContext(self)
@property
def dtype(self):
    """Name of the integer dtype used by this graph index.

    Returns
    -------
    str
        "int32" when ``self.nbits()`` is 32, "int64" for any other value.
    """
    return "int32" if self.nbits() == 32 else "int64"
def copy_to(self, ctx): def copy_to(self, ctx):
"""Copy this immutable graph index to the given device context. """Copy this immutable graph index to the given device context.
......
...@@ -1091,6 +1091,28 @@ class DGLHeteroGraph(object): ...@@ -1091,6 +1091,28 @@ class DGLHeteroGraph(object):
""" """
return self._graph.is_readonly() return self._graph.is_readonly()
@property
def idtype(self):
    """Backend dtype object matching the dtype of the graph index.

    The dtype name reported by ``self._graph.dtype`` is looked up as an
    attribute of the backend module ``F``.

    Returns
    -------
    backend dtype object
        th.int32/th.int64 or tf.int32/tf.int64 etc.
    """
    return getattr(F, self._graph.dtype)

@property
def _idtype_str(self):
    """Name of the dtype of the graph index.

    Unlike ``idtype``, the value of ``self._graph.dtype`` is returned as-is
    and is not converted into a backend dtype object.

    Returns
    -------
    str
        The dtype name; presumably "int32" or "int64" -- verify against
        the heterograph index implementation.
    """
    return self._graph.dtype
def has_node(self, vid, ntype=None): def has_node(self, vid, ntype=None):
"""Whether the graph has a node with a particular id and type. """Whether the graph has a node with a particular id and type.
...@@ -1148,7 +1170,7 @@ class DGLHeteroGraph(object): ...@@ -1148,7 +1170,7 @@ class DGLHeteroGraph(object):
-------- --------
has_node has_node
""" """
vids = utils.toindex(vids) vids = utils.toindex(vids, self._idtype_str)
rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids) rst = self._graph.has_nodes(self.get_ntype_id(ntype), vids)
return rst.tousertensor() return rst.tousertensor()
...@@ -1214,8 +1236,8 @@ class DGLHeteroGraph(object): ...@@ -1214,8 +1236,8 @@ class DGLHeteroGraph(object):
-------- --------
has_edge_between has_edge_between
""" """
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
rst = self._graph.has_edges_between(self.get_etype_id(etype), u, v) rst = self._graph.has_edges_between(self.get_etype_id(etype), u, v)
return rst.tousertensor() return rst.tousertensor()
...@@ -1293,6 +1315,7 @@ class DGLHeteroGraph(object): ...@@ -1293,6 +1315,7 @@ class DGLHeteroGraph(object):
-------- --------
predecessors predecessors
""" """
check_same_dtype(self._idtype_str, v)
return self._graph.successors(self.get_etype_id(etype), v).tousertensor() return self._graph.successors(self.get_etype_id(etype), v).tousertensor()
def edge_id(self, u, v, force_multi=None, return_array=False, etype=None): def edge_id(self, u, v, force_multi=None, return_array=False, etype=None):
...@@ -1425,8 +1448,10 @@ class DGLHeteroGraph(object): ...@@ -1425,8 +1448,10 @@ class DGLHeteroGraph(object):
-------- --------
edge_id edge_id
""" """
u = utils.toindex(u) check_same_dtype(self._idtype_str, u)
v = utils.toindex(v) check_same_dtype(self._idtype_str, v)
u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v, self._idtype_str)
src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v) src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
if force_multi is not None: if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \ dgl_warning("force_multi will be deprecated, " \
...@@ -1471,7 +1496,8 @@ class DGLHeteroGraph(object): ...@@ -1471,7 +1496,8 @@ class DGLHeteroGraph(object):
>>> g.find_edges([0, 2]) >>> g.find_edges([0, 2])
(tensor([0, 1]), tensor([0, 2])) (tensor([0, 1]), tensor([0, 2]))
""" """
eid = utils.toindex(eid) check_same_dtype(self._idtype_str, eid)
eid = utils.toindex(eid, self._idtype_str)
src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid) src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src.tousertensor(), dst.tousertensor() return src.tousertensor(), dst.tousertensor()
...@@ -1517,7 +1543,8 @@ class DGLHeteroGraph(object): ...@@ -1517,7 +1543,8 @@ class DGLHeteroGraph(object):
>>> g.in_edges([0, 2], form='uv') >>> g.in_edges([0, 2], form='uv')
(tensor([0, 1]), tensor([0, 2])) (tensor([0, 1]), tensor([0, 2]))
""" """
v = utils.toindex(v) check_same_dtype(self._idtype_str, v)
v = utils.toindex(v, self._idtype_str)
src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v) src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
...@@ -1568,7 +1595,8 @@ class DGLHeteroGraph(object): ...@@ -1568,7 +1595,8 @@ class DGLHeteroGraph(object):
>>> g.out_edges([0, 1], form='uv') >>> g.out_edges([0, 1], form='uv')
(tensor([0, 1, 1]), tensor([0, 1, 2])) (tensor([0, 1, 1]), tensor([0, 1, 2]))
""" """
u = utils.toindex(u) check_same_dtype(self._idtype_str, u)
u = utils.toindex(u, self._idtype_str)
src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u) src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
if form == 'all': if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor()) return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
...@@ -1710,12 +1738,13 @@ class DGLHeteroGraph(object): ...@@ -1710,12 +1738,13 @@ class DGLHeteroGraph(object):
-------- --------
in_degree in_degree
""" """
check_same_dtype(self._idtype_str, v)
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
_, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
if is_all(v): if is_all(v):
v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid))) v = utils.toindex(slice(0, self._graph.number_of_nodes(dtid)), self._idtype_str)
else: else:
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
return self._graph.in_degrees(etid, v).tousertensor() return self._graph.in_degrees(etid, v).tousertensor()
def out_degree(self, u, etype=None): def out_degree(self, u, etype=None):
...@@ -1795,12 +1824,13 @@ class DGLHeteroGraph(object): ...@@ -1795,12 +1824,13 @@ class DGLHeteroGraph(object):
-------- --------
out_degree out_degree
""" """
check_same_dtype(self._idtype_str, u)
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, _ = self._graph.metagraph.find_edge(etid) stid, _ = self._graph.metagraph.find_edge(etid)
if is_all(u): if is_all(u):
u = utils.toindex(slice(0, self._graph.number_of_nodes(stid))) u = utils.toindex(slice(0, self._graph.number_of_nodes(stid)), self._idtype_str)
else: else:
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
return self._graph.out_degrees(etid, u).tousertensor() return self._graph.out_degrees(etid, u).tousertensor()
def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges): def _create_hetero_subgraph(self, sgi, induced_nodes, induced_edges):
...@@ -1893,7 +1923,9 @@ class DGLHeteroGraph(object): ...@@ -1893,7 +1923,9 @@ class DGLHeteroGraph(object):
-------- --------
edge_subgraph edge_subgraph
""" """
induced_nodes = [utils.toindex(nodes.get(ntype, [])) for ntype in self.ntypes] check_same_dtype(self._idtype_str, nodes)
induced_nodes = [utils.toindex(nodes.get(ntype, []), self._idtype_str)
for ntype in self.ntypes]
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
induced_edges = sgi.induced_edges induced_edges = sgi.induced_edges
...@@ -1974,9 +2006,10 @@ class DGLHeteroGraph(object): ...@@ -1974,9 +2006,10 @@ class DGLHeteroGraph(object):
-------- --------
subgraph subgraph
""" """
check_idtype_dict(self._idtype_str, edges)
edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()} edges = {self.to_canonical_etype(etype): e for etype, e in edges.items()}
induced_edges = [ induced_edges = [
utils.toindex(edges.get(canonical_etype, [])) utils.toindex(edges.get(canonical_etype, []), self._idtype_str)
for canonical_etype in self.canonical_etypes] for canonical_etype in self.canonical_etypes]
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes)
induced_nodes = sgi.induced_nodes induced_nodes = sgi.induced_nodes
...@@ -2057,9 +2090,11 @@ class DGLHeteroGraph(object): ...@@ -2057,9 +2090,11 @@ class DGLHeteroGraph(object):
edge_frames.append(self._edge_frames[i]) edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True) metagraph = graph_index.from_edge_list(meta_edges, True)
# num_nodes_per_type doesn't need to be int32
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type)) metagraph, rel_graphs, utils.toindex(num_nodes_per_type, "int64"))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes,
node_frames, edge_frames)
return hg return hg
def edge_type_subgraph(self, etypes): def edge_type_subgraph(self, etypes):
...@@ -2117,7 +2152,8 @@ class DGLHeteroGraph(object): ...@@ -2117,7 +2152,8 @@ class DGLHeteroGraph(object):
node_type_subgraph node_type_subgraph
""" """
etype_ids = [self.get_etype_id(etype) for etype in etypes] etype_ids = [self.get_etype_id(etype) for etype in etypes]
meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids)) # meta graph is homograph, still using int64
meta_src, meta_dst, _ = self._graph.metagraph.find_edges(utils.toindex(etype_ids, "int64"))
rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids] rel_graphs = [self._graph.get_relation_graph(i) for i in etype_ids]
meta_src = meta_src.tonumpy() meta_src = meta_src.tonumpy()
meta_dst = meta_dst.tonumpy() meta_dst = meta_dst.tonumpy()
...@@ -2131,8 +2167,9 @@ class DGLHeteroGraph(object): ...@@ -2131,8 +2167,9 @@ class DGLHeteroGraph(object):
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes] num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True) metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
# num_nodes_per_type should be int64
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type)) metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64"))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
return hg return hg
...@@ -2433,7 +2470,7 @@ class DGLHeteroGraph(object): ...@@ -2433,7 +2470,7 @@ class DGLHeteroGraph(object):
if is_all(u): if is_all(u):
num_nodes = self._graph.number_of_nodes(ntid) num_nodes = self._graph.number_of_nodes(ntid)
else: else:
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
num_nodes = len(u) num_nodes = len(u)
for key, val in data.items(): for key, val in data.items():
nfeats = F.shape(val)[0] nfeats = F.shape(val)[0]
...@@ -2467,7 +2504,7 @@ class DGLHeteroGraph(object): ...@@ -2467,7 +2504,7 @@ class DGLHeteroGraph(object):
if is_all(u): if is_all(u):
return dict(self._node_frames[ntid]) return dict(self._node_frames[ntid])
else: else:
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
return self._node_frames[ntid].select_rows(u) return self._node_frames[ntid].select_rows(u)
def _pop_n_repr(self, ntid, key): def _pop_n_repr(self, ntid, key):
...@@ -2520,12 +2557,12 @@ class DGLHeteroGraph(object): ...@@ -2520,12 +2557,12 @@ class DGLHeteroGraph(object):
eid = ALL eid = ALL
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etid, u, v) _, _, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
# sanity check # sanity check
if not utils.is_dict_like(data): if not utils.is_dict_like(data):
...@@ -2535,7 +2572,7 @@ class DGLHeteroGraph(object): ...@@ -2535,7 +2572,7 @@ class DGLHeteroGraph(object):
if is_all(eid): if is_all(eid):
num_edges = self._graph.number_of_edges(etid) num_edges = self._graph.number_of_edges(etid)
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid, self._idtype_str)
num_edges = len(eid) num_edges = len(eid)
for key, val in data.items(): for key, val in data.items():
nfeats = F.shape(val)[0] nfeats = F.shape(val)[0]
...@@ -2572,17 +2609,17 @@ class DGLHeteroGraph(object): ...@@ -2572,17 +2609,17 @@ class DGLHeteroGraph(object):
eid = ALL eid = ALL
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(etid, u, v) _, _, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
if is_all(eid): if is_all(eid):
return dict(self._edge_frames[etid]) return dict(self._edge_frames[etid])
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid, self._idtype_str)
return self._edge_frames[etid].select_rows(eid) return self._edge_frames[etid].select_rows(eid)
def _pop_e_repr(self, etid, key): def _pop_e_repr(self, etid, key):
...@@ -2640,11 +2677,12 @@ class DGLHeteroGraph(object): ...@@ -2640,11 +2677,12 @@ class DGLHeteroGraph(object):
-------- --------
apply_edges apply_edges
""" """
check_same_dtype(self._idtype_str, v)
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
if is_all(v): if is_all(v):
v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype))) v_ntype = utils.toindex(slice(0, self.number_of_nodes(ntype)), self._idtype_str)
else: else:
v_ntype = utils.toindex(v) v_ntype = utils.toindex(v, self._idtype_str)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid], scheduler.schedule_apply_nodes(v_ntype, func, self._node_frames[ntid],
inplace=inplace, ntype=self._ntypes[ntid]) inplace=inplace, ntype=self._ntypes[ntid])
...@@ -2687,19 +2725,20 @@ class DGLHeteroGraph(object): ...@@ -2687,19 +2725,20 @@ class DGLHeteroGraph(object):
apply_nodes apply_nodes
group_apply_edges group_apply_edges
""" """
check_same_dtype(self._idtype_str, edges)
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etid, 'eid') u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype))) eid = utils.toindex(slice(0, self.number_of_edges(etype)), self._idtype_str)
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog: with ir.prog() as prog:
...@@ -2748,6 +2787,7 @@ class DGLHeteroGraph(object): ...@@ -2748,6 +2787,7 @@ class DGLHeteroGraph(object):
-------- --------
apply_edges apply_edges
""" """
check_same_dtype(self._idtype_str, edges)
if group_by not in ('src', 'dst'): if group_by not in ('src', 'dst'):
raise DGLError("Group_by should be either src or dst") raise DGLError("Group_by should be either src or dst")
...@@ -2755,15 +2795,15 @@ class DGLHeteroGraph(object): ...@@ -2755,15 +2795,15 @@ class DGLHeteroGraph(object):
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etid, 'eid') u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype))) eid = utils.toindex(slice(0, self.number_of_edges(etype)), self._idtype_str)
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
with ir.prog() as prog: with ir.prog() as prog:
...@@ -2833,21 +2873,22 @@ class DGLHeteroGraph(object): ...@@ -2833,21 +2873,22 @@ class DGLHeteroGraph(object):
>>> # Send the feature of source nodes along multiple edges specified by their end points >>> # Send the feature of source nodes along multiple edges specified by their end points
>>> g.send(([0, 1], [1, 2]), fn.copy_src('h', 'm')) >>> g.send(([0, 1], [1, 2]), fn.copy_src('h', 'm'))
""" """
check_same_dtype(self._idtype_str, edges)
assert message_func is not None assert message_func is not None
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid))) eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)), self._idtype_str)
u, v, _ = self._graph.edges(etid, 'eid') u, v, _ = self._graph.edges(etid, 'eid')
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
if len(eid) == 0: if len(eid) == 0:
...@@ -2932,13 +2973,14 @@ class DGLHeteroGraph(object): ...@@ -2932,13 +2973,14 @@ class DGLHeteroGraph(object):
[0.], [0.],
[1.]]) [1.]])
""" """
check_same_dtype(self._idtype_str, v)
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(v): if is_all(v):
v = F.arange(0, self.number_of_nodes(dtid)) v = F.arange(0, self.number_of_nodes(dtid), self._idtype_str)
elif isinstance(v, int): elif isinstance(v, int):
v = [v] v = [v]
v = utils.toindex(v) v = utils.toindex(v, dtype=self._idtype_str)
if len(v) == 0: if len(v) == 0:
# no vertex to be triggered. # no vertex to be triggered.
return return
...@@ -3005,14 +3047,15 @@ class DGLHeteroGraph(object): ...@@ -3005,14 +3047,15 @@ class DGLHeteroGraph(object):
tensor([[0.], tensor([[0.],
[2.]]) [2.]])
""" """
check_same_dtype(self._idtype_str, v)
# infer receive node type # infer receive node type
ntype = infer_ntype_from_dict(self, reducer_dict) ntype = infer_ntype_from_dict(self, reducer_dict)
ntid = self.get_ntype_id_from_dst(ntype) ntid = self.get_ntype_id_from_dst(ntype)
if is_all(v): if is_all(v):
v = F.arange(0, self.number_of_nodes(ntid)) v = F.arange(0, self.number_of_nodes(ntid), self._idtype_str)
elif isinstance(v, int): elif isinstance(v, int):
v = [v] v = [v]
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
if len(v) == 0: if len(v) == 0:
return return
# TODO(minjie): currently loop over each edge type and reuse the old schedule. # TODO(minjie): currently loop over each edge type and reuse the old schedule.
...@@ -3122,12 +3165,12 @@ class DGLHeteroGraph(object): ...@@ -3122,12 +3165,12 @@ class DGLHeteroGraph(object):
if isinstance(edges, tuple): if isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
if len(u) == 0: if len(u) == 0:
...@@ -3235,12 +3278,12 @@ class DGLHeteroGraph(object): ...@@ -3235,12 +3278,12 @@ class DGLHeteroGraph(object):
edges, mfunc, rfunc, afunc = args edges, mfunc, rfunc, afunc = args
if isinstance(edges, tuple): if isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
all_vs.append(v) all_vs.append(v)
if len(u) == 0: if len(u) == 0:
...@@ -3330,11 +3373,12 @@ class DGLHeteroGraph(object): ...@@ -3330,11 +3373,12 @@ class DGLHeteroGraph(object):
[1.], [1.],
[1.]]) [1.]])
""" """
check_same_dtype(self._idtype_str, v)
# only one type of edges # only one type of edges
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
...@@ -3403,7 +3447,8 @@ class DGLHeteroGraph(object): ...@@ -3403,7 +3447,8 @@ class DGLHeteroGraph(object):
tensor([[0.], tensor([[0.],
[3.]]) [3.]])
""" """
v = utils.toindex(v) check_same_dtype(self._idtype_str, v)
v = utils.toindex(v, self._idtype_str)
if len(v) == 0: if len(v) == 0:
return return
# infer receive node type # infer receive node type
...@@ -3495,11 +3540,12 @@ class DGLHeteroGraph(object): ...@@ -3495,11 +3540,12 @@ class DGLHeteroGraph(object):
[0.], [0.],
[0.]]) [0.]])
""" """
check_same_dtype(self._idtype_str, u)
# only one type of edges # only one type of edges
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
if len(u) == 0: if len(u) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
...@@ -3870,11 +3916,12 @@ class DGLHeteroGraph(object): ...@@ -3870,11 +3916,12 @@ class DGLHeteroGraph(object):
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user') >>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
tensor([1, 2]) tensor([1, 2])
""" """
check_same_dtype(self._idtype_str, nodes)
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
if is_all(nodes): if is_all(nodes):
v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid))) v = utils.toindex(slice(0, self._graph.number_of_nodes(ntid)), self._idtype_str)
else: else:
v = utils.toindex(nodes) v = utils.toindex(nodes, self._idtype_str)
n_repr = self._get_n_repr(ntid, v) n_repr = self._get_n_repr(ntid, v)
nbatch = NodeBatch(v, n_repr, ntype=self.ntypes[ntid]) nbatch = NodeBatch(v, n_repr, ntype=self.ntypes[ntid])
...@@ -3920,19 +3967,20 @@ class DGLHeteroGraph(object): ...@@ -3920,19 +3967,20 @@ class DGLHeteroGraph(object):
>>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows') >>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
tensor([1, 2]) tensor([1, 2])
""" """
check_same_dtype(self._idtype_str, edges)
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etid, 'eid') u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid))) eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)), self._idtype_str)
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u, self._idtype_str)
v = utils.toindex(v) v = utils.toindex(v, self._idtype_str)
# Rewrite u, v to handle edge broadcasting and multigraph. # Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(etid, u, v) u, v, eid = self._graph.edge_ids(etid, u, v)
else: else:
eid = utils.toindex(edges) eid = utils.toindex(edges, self._idtype_str)
u, v, _ = self._graph.find_edges(etid, eid) u, v, _ = self._graph.find_edges(etid, eid)
src_data = self._get_n_repr(stid, u) src_data = self._get_n_repr(stid, u)
...@@ -4102,6 +4150,54 @@ class DGLHeteroGraph(object): ...@@ -4102,6 +4150,54 @@ class DGLHeteroGraph(object):
"""Return if the graph is homogeneous.""" """Return if the graph is homogeneous."""
return len(self.ntypes) == 1 and len(self.etypes) == 1 return len(self.ntypes) == 1 and len(self.etypes) == 1
def long(self):
    """Return a heterograph object that uses int64 as its index dtype.

    The returned object is constructed from ``self._graph.asbits(64)`` and
    receives this graph's node type names, edge type names, node frames and
    edge frames, so its ndata and edata are those of the original object.

    Returns
    -------
    DGLHeteroGraph
        The graph object

    Examples
    --------
    >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
    ...                   index_dtype='int32')
    >>> g_long = g.long()  # Convert g to int64 indexed, not changing the original `g`

    See Also
    --------
    int
    """
    # Convert only the graph index; type names and frames are passed through.
    index64 = self._graph.asbits(64)
    return DGLHeteroGraph(
        index64, self.ntypes, self.etypes, self._node_frames, self._edge_frames)
def int(self):
    """Return a heterograph object that uses int32 as its index dtype.

    The returned object is constructed from ``self._graph.asbits(32)`` and
    receives this graph's node type names, edge type names, node frames and
    edge frames, so its ndata and edata are those of the original object.

    Returns
    -------
    DGLHeteroGraph
        The graph object

    Examples
    --------
    >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
    ...                   index_dtype='int64')
    >>> g_int = g.int()  # Convert g to int32 indexed, not changing the original `g`

    See Also
    --------
    long
    """
    # Convert only the graph index; type names and frames are passed through.
    index32 = self._graph.asbits(32)
    return DGLHeteroGraph(
        index32, self.ntypes, self.etypes, self._node_frames, self._edge_frames)
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
...@@ -4425,3 +4521,18 @@ class AdaptedHeteroGraph(GraphAdapter): ...@@ -4425,3 +4521,18 @@ class AdaptedHeteroGraph(GraphAdapter):
def canonical_etype(self): def canonical_etype(self):
"""Canonical edge type.""" """Canonical edge type."""
return self.graph.canonical_etypes[self.etid] return self.graph.canonical_etypes[self.etid]
def check_same_dtype(graph_dtype, tensor):
    """Check that a tensor's dtype is consistent with the graph's index dtype.

    Inputs for which ``F.is_tensor`` returns false are accepted unchecked.

    Parameters
    ----------
    graph_dtype
        The graph index dtype, compared against the value that
        ``F.reverse_data_type_dict`` holds for the tensor's dtype
        (presumably a name such as 'int32' or 'int64' -- confirm with callers).
    tensor
        The input to validate.

    Raises
    ------
    utils.InconsistentDtypeException
        If ``tensor`` is a backend tensor whose dtype entry differs from
        ``graph_dtype``.
    """
    if not F.is_tensor(tensor):
        return
    # Look the dtype name up once; it is needed for both the check and the message.
    tensor_dtype = F.reverse_data_type_dict[F.dtype(tensor)]
    if graph_dtype != tensor_dtype:
        raise utils.InconsistentDtypeException(
            "Expect the input tensor to be the same as the graph index dtype({}), but got {}"
            .format(graph_dtype, tensor_dtype))
def check_idtype_dict(graph_dtype, tensor_dict):
    """Check every value of a dict against the graph's index dtype.

    Parameters
    ----------
    graph_dtype
        The graph index dtype, forwarded to ``check_same_dtype``.
    tensor_dict : dict
        Mapping whose values are validated; the keys are not inspected.
    """
    for tensor in tensor_dict.values():
        check_same_dtype(graph_dtype, tensor)
...@@ -129,6 +129,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -129,6 +129,7 @@ class HeteroGraphIndex(ObjectBase):
_CAPI_DGLHeteroClear(self) _CAPI_DGLHeteroClear(self)
self._cache.clear() self._cache.clear()
@property
def dtype(self): def dtype(self):
"""Return the data type of this graph index. """Return the data type of this graph index.
...@@ -149,16 +150,6 @@ class HeteroGraphIndex(ObjectBase): ...@@ -149,16 +150,6 @@ class HeteroGraphIndex(ObjectBase):
""" """
return _CAPI_DGLHeteroContext(self) return _CAPI_DGLHeteroContext(self)
def nbits(self):
    """Return how many integer bits the storage uses (32 or 64).

    Returns
    -------
    int
        The number of bits.
    """
    num_bits = _CAPI_DGLHeteroNumBits(self)
    return num_bits
def bits_needed(self, etype): def bits_needed(self, etype):
"""Return the number of integer bits needed to represent the unitgraph graph. """Return the number of integer bits needed to represent the unitgraph graph.
...@@ -298,7 +289,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -298,7 +289,7 @@ class HeteroGraphIndex(ObjectBase):
0-1 array indicating existence 0-1 array indicating existence
""" """
vid_array = vids.todgltensor() vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array)) return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array), self.dtype)
def has_edge_between(self, etype, u, v): def has_edge_between(self, etype, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
...@@ -339,7 +330,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -339,7 +330,7 @@ class HeteroGraphIndex(ObjectBase):
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween( return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
self, int(etype), u_array, v_array)) self, int(etype), u_array, v_array), self.dtype)
def predecessors(self, etype, v): def predecessors(self, etype, v):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -359,7 +350,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -359,7 +350,7 @@ class HeteroGraphIndex(ObjectBase):
Array of predecessors Array of predecessors
""" """
return utils.toindex(_CAPI_DGLHeteroPredecessors( return utils.toindex(_CAPI_DGLHeteroPredecessors(
self, int(etype), int(v))) self, int(etype), int(v)), self.dtype)
def successors(self, etype, v): def successors(self, etype, v):
"""Return the successors of the node. """Return the successors of the node.
...@@ -379,7 +370,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -379,7 +370,7 @@ class HeteroGraphIndex(ObjectBase):
Array of successors Array of successors
""" """
return utils.toindex(_CAPI_DGLHeteroSuccessors( return utils.toindex(_CAPI_DGLHeteroSuccessors(
self, int(etype), int(v))) self, int(etype), int(v)), self.dtype)
def edge_id(self, etype, u, v): def edge_id(self, etype, u, v):
"""Return the id array of all edges between u and v. """Return the id array of all edges between u and v.
...@@ -399,7 +390,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -399,7 +390,7 @@ class HeteroGraphIndex(ObjectBase):
The edge id array. The edge id array.
""" """
return utils.toindex(_CAPI_DGLHeteroEdgeId( return utils.toindex(_CAPI_DGLHeteroEdgeId(
self, int(etype), int(u), int(v))) self, int(etype), int(u), int(v)), self.dtype)
def edge_ids(self, etype, u, v): def edge_ids(self, etype, u, v):
"""Return a triplet of arrays that contains the edge IDs. """Return a triplet of arrays that contains the edge IDs.
...@@ -426,9 +417,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -426,9 +417,9 @@ class HeteroGraphIndex(ObjectBase):
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array) edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0), self.dtype)
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1), self.dtype)
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2), self.dtype)
return src, dst, eid return src, dst, eid
...@@ -454,9 +445,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -454,9 +445,9 @@ class HeteroGraphIndex(ObjectBase):
eid_array = eid.todgltensor() eid_array = eid.todgltensor()
edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array) edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0), self.dtype)
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1), self.dtype)
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2), self.dtype)
return src, dst, eid return src, dst, eid
...@@ -486,9 +477,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -486,9 +477,9 @@ class HeteroGraphIndex(ObjectBase):
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array) edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0), self.dtype)
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1), self.dtype)
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2), self.dtype)
return src, dst, eid return src, dst, eid
def out_edges(self, etype, v): def out_edges(self, etype, v):
...@@ -517,9 +508,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -517,9 +508,9 @@ class HeteroGraphIndex(ObjectBase):
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array) edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0), self.dtype)
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1), self.dtype)
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2), self.dtype)
return src, dst, eid return src, dst, eid
@utils.cached_member(cache='_cache', prefix='edges') @utils.cached_member(cache='_cache', prefix='edges')
...@@ -552,9 +543,9 @@ class HeteroGraphIndex(ObjectBase): ...@@ -552,9 +543,9 @@ class HeteroGraphIndex(ObjectBase):
src = edge_array(0) src = edge_array(0)
dst = edge_array(1) dst = edge_array(1)
eid = edge_array(2) eid = edge_array(2)
src = utils.toindex(src) src = utils.toindex(src, self.dtype)
dst = utils.toindex(dst) dst = utils.toindex(dst, self.dtype)
eid = utils.toindex(eid) eid = utils.toindex(eid, self.dtype)
return src, dst, eid return src, dst, eid
def in_degree(self, etype, v): def in_degree(self, etype, v):
...@@ -594,7 +585,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -594,7 +585,7 @@ class HeteroGraphIndex(ObjectBase):
The in degree array. The in degree array.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array)) return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array), self.dtype)
def out_degree(self, etype, v): def out_degree(self, etype, v):
"""Return the out degree of the node. """Return the out degree of the node.
...@@ -633,7 +624,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -633,7 +624,7 @@ class HeteroGraphIndex(ObjectBase):
The out degree array. The out degree array.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array)) return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array), self.dtype)
def adjacency_matrix(self, etype, transpose, ctx): def adjacency_matrix(self, etype, transpose, ctx):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -672,18 +663,20 @@ class HeteroGraphIndex(ObjectBase): ...@@ -672,18 +663,20 @@ class HeteroGraphIndex(ObjectBase):
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
nnz = self.number_of_edges(etype) nnz = self.number_of_edges(etype)
if fmt == "csr": if fmt == "csr":
indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx) indptr = F.copy_to(utils.toindex(rst(0), self.dtype).tousertensor(), ctx)
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx) indices = F.copy_to(utils.toindex(rst(1), self.dtype).tousertensor(), ctx)
shuffle = utils.toindex(rst(2)) shuffle = utils.toindex(rst(2), self.dtype)
dat = F.ones(nnz, dtype=F.float32, ctx=ctx) # FIXME(minjie): data type dat = F.ones(nnz, dtype=F.float32, ctx=ctx) # FIXME(minjie): data type
spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0] spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
return spmat, shuffle return spmat, shuffle
elif fmt == "coo": elif fmt == "coo":
idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx) idx = F.copy_to(utils.toindex(rst(0), self.dtype).tousertensor(), ctx)
idx = F.reshape(idx, (2, nnz)) idx = F.reshape(idx, (2, nnz))
dat = F.ones((nnz,), dtype=F.float32, ctx=ctx) dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols)) adj, shuffle_idx = F.sparse_matrix(
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None dat, ('coo', idx), (nrows, ncols))
shuffle_idx = utils.toindex(
shuffle_idx, self.dtype) if shuffle_idx is not None else None
return adj, shuffle_idx return adj, shuffle_idx
else: else:
raise Exception("unknown format") raise Exception("unknown format")
...@@ -732,12 +725,12 @@ class HeteroGraphIndex(ObjectBase): ...@@ -732,12 +725,12 @@ class HeteroGraphIndex(ObjectBase):
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
nnz = self.number_of_edges(etype) nnz = self.number_of_edges(etype)
if fmt == "csr": if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy() indptr = utils.toindex(rst(0), self.dtype).tonumpy()
indices = utils.toindex(rst(1)).tonumpy() indices = utils.toindex(rst(1), self.dtype).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices) data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols)) return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
elif fmt == 'coo': elif fmt == 'coo':
idx = utils.toindex(rst(0)).tonumpy() idx = utils.toindex(rst(0), self.dtype).tonumpy()
row, col = np.reshape(idx, (2, nnz)) row, col = np.reshape(idx, (2, nnz))
data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row) data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row)
return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols)) return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols))
...@@ -906,7 +899,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -906,7 +899,8 @@ class HeteroGraphIndex(ObjectBase):
order = csr(2) order = csr(2)
rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), False, "csr") rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), False, "csr")
rev_order = rev_csr(2) rev_order = rev_csr(2)
return utils.toindex(order), utils.toindex(rev_order) return utils.toindex(order, self.dtype), utils.toindex(rev_order, self.dtype)
@register_object('graph.HeteroSubgraph') @register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase): class HeteroSubgraphIndex(ObjectBase):
...@@ -933,7 +927,7 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -933,7 +927,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced nodes Induced nodes
""" """
ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self) ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
return [utils.toindex(v.data) for v in ret] return [utils.toindex(v.data, self.graph.dtype) for v in ret]
@property @property
def induced_edges(self): def induced_edges(self):
...@@ -946,7 +940,7 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -946,7 +940,7 @@ class HeteroSubgraphIndex(ObjectBase):
Induced edges Induced edges
""" """
ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self) ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
return [utils.toindex(v.data) for v in ret] return [utils.toindex(v.data, self.graph.dtype) for v in ret]
################################################################# #################################################################
# Creators # Creators
......
...@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out, ...@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
The rows to write to output tensor. The rows to write to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = nd.NULL A_rows = nd.NULL[G.dtype]
if B_rows is None: if B_rows is None:
B_rows = nd.NULL B_rows = nd.NULL[G.dtype]
if out_rows is None: if out_rows is None:
out_rows = nd.NULL out_rows = nd.NULL[G.dtype]
_CAPI_DGLKernelBinaryOpReduce( _CAPI_DGLKernelBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce( ...@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce(
The rows written to output tensor. The rows written to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = nd.NULL A_rows = nd.NULL[G.dtype]
if B_rows is None: if B_rows is None:
B_rows = nd.NULL B_rows = nd.NULL[G.dtype]
if out_rows is None: if out_rows is None:
out_rows = nd.NULL out_rows = nd.NULL[G.dtype]
_CAPI_DGLKernelBackwardLhsBinaryOpReduce( _CAPI_DGLKernelBackwardLhsBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce( ...@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce(
The rows written to output tensor. The rows written to output tensor.
""" """
if A_rows is None: if A_rows is None:
A_rows = nd.NULL A_rows = nd.NULL[G.dtype]
if B_rows is None: if B_rows is None:
B_rows = nd.NULL B_rows = nd.NULL[G.dtype]
if out_rows is None: if out_rows is None:
out_rows = nd.NULL out_rows = nd.NULL[G.dtype]
_CAPI_DGLKernelBackwardRhsBinaryOpReduce( _CAPI_DGLKernelBackwardRhsBinaryOpReduce(
reducer, op, G, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
...@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target, ...@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target,
The rows to write to output tensor. The rows to write to output tensor.
""" """
if X_rows is None: if X_rows is None:
X_rows = nd.NULL X_rows = nd.NULL[G.dtype]
if out_rows is None: if out_rows is None:
out_rows = nd.NULL out_rows = nd.NULL[G.dtype]
_CAPI_DGLKernelCopyReduce( _CAPI_DGLKernelCopyReduce(
reducer, G, int(target), reducer, G, int(target),
X, out, X_rows, out_rows) X, out, X_rows, out_rows)
...@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target, ...@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target,
The rows written to output tensor. The rows written to output tensor.
""" """
if X_rows is None: if X_rows is None:
X_rows = nd.NULL X_rows = nd.NULL[G.dtype]
if out_rows is None: if out_rows is None:
out_rows = nd.NULL out_rows = nd.NULL[G.dtype]
_CAPI_DGLKernelBackwardCopyReduce( _CAPI_DGLKernelBackwardCopyReduce(
reducer, G, int(target), reducer, G, int(target),
X, out, grad_out, grad_X, X, out, grad_out, grad_X,
......
...@@ -177,4 +177,7 @@ _init_api("dgl.ndarray") ...@@ -177,4 +177,7 @@ _init_api("dgl.ndarray")
# An array representing null (no value) that can be safely converted to # An array representing null (no value) that can be safely converted to
# other backend tensors. # other backend tensors.
NULL = array(_np.array([], dtype=_np.int64)) NULL = {
"int64": array(_np.array([], dtype=_np.int64)),
"int32": array(_np.array([], dtype=_np.int32))
}
...@@ -70,7 +70,7 @@ def choice(a, size, replace=True, prob=None): # pylint: disable=invalid-name ...@@ -70,7 +70,7 @@ def choice(a, size, replace=True, prob=None): # pylint: disable=invalid-name
population = a population = a
if prob is None: if prob is None:
prob = nd.NULL prob = nd.NULL["int64"]
else: else:
prob = F.zerocopy_to_dgl_ndarray(prob) prob = F.zerocopy_to_dgl_ndarray(prob)
......
"""Module for degree bucketing schedulers.""" """Module for degree bucketing schedulers."""
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -118,11 +119,12 @@ def _process_node_buckets(buckets): ...@@ -118,11 +119,12 @@ def _process_node_buckets(buckets):
The zero-degree nodes The zero-degree nodes
""" """
# get back results # get back results
degs = utils.toindex(buckets(0)) dtype = buckets(0).dtype
v = utils.toindex(buckets(1)) degs = utils.toindex(buckets(0), dtype)
v = utils.toindex(buckets(1), dtype)
# XXX: convert directly from ndarray to python list? # XXX: convert directly from ndarray to python list?
v_section = buckets(2).asnumpy().tolist() v_section = buckets(2).asnumpy().tolist()
msg_ids = utils.toindex(buckets(3)) msg_ids = utils.toindex(buckets(3), dtype)
msg_section = buckets(4).asnumpy().tolist() msg_section = buckets(4).asnumpy().tolist()
# split buckets # split buckets
...@@ -131,8 +133,8 @@ def _process_node_buckets(buckets): ...@@ -131,8 +133,8 @@ def _process_node_buckets(buckets):
msg_ids = F.split(msg_ids, msg_section, 0) msg_ids = F.split(msg_ids, msg_section, 0)
# convert to utils.Index # convert to utils.Index
dsts = [utils.toindex(dst) for dst in dsts] dsts = [utils.toindex(dst, dtype) for dst in dsts]
msg_ids = [utils.toindex(msg_id) for msg_id in msg_ids] msg_ids = [utils.toindex(msg_id, dtype) for msg_id in msg_ids]
# handle zero deg # handle zero deg
degs = degs.tonumpy() degs = degs.tonumpy()
...@@ -266,17 +268,18 @@ def _process_edge_buckets(buckets): ...@@ -266,17 +268,18 @@ def _process_edge_buckets(buckets):
A list of edge id buckets A list of edge id buckets
""" """
# get back results # get back results
dtype = buckets(0).dtype
degs = buckets(0).asnumpy() degs = buckets(0).asnumpy()
uids = utils.toindex(buckets(1)) uids = utils.toindex(buckets(1), dtype)
vids = utils.toindex(buckets(2)) vids = utils.toindex(buckets(2), dtype)
eids = utils.toindex(buckets(3)) eids = utils.toindex(buckets(3), dtype)
# XXX: convert directly from ndarray to python list? # XXX: convert directly from ndarray to python list?
sections = buckets(4).asnumpy().tolist() sections = buckets(4).asnumpy().tolist()
# split buckets and convert to index # split buckets and convert to index
def split(to_split): def split(to_split):
res = F.split(to_split.tousertensor(), sections, 0) res = F.split(to_split.tousertensor(), sections, 0)
return map(utils.toindex, res) return map(partial(utils.toindex, dtype=dtype), res)
uids = split(uids) uids = split(uids)
vids = split(vids) vids = split(vids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment