"tools/python/git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "c9bdb9b2da449d6ff37e1345648bbf2a32655c34"
Unverified commit dc8ca88e authored by Jinjing Zhou, committed by GitHub

[Refactor] Explicit dtype for HeteroGraph (#1467)



* 111
* 111
* lint
* lint
* lint
* lint
* fix
* lint
* try
* fix
* lint
* lint
* test
* fix
* ttt
* test
* fix
* fix
* fix
* mxnet
* 111
* fix 64bits computation
* pylint
* roll back
* fix
* lint
* fix hetero_from_relations
* remove index_dtype in to_homo and to_hetero
* fix
* fix
* fix
* fix
* remove default
* fix
* lint
* fix
* fix error message
* fix error
* lint
* macro dispatch
* try
* lint
* remove nbits
* error message
* fix
* fix
* lint
* lint
* lint
* fix
* lint
* fix
* fix random walk
* lint
* lint
* fix
* fix
* fix
* lint
* fix
* lint
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent de34e15a
@@ -100,7 +100,7 @@ pipeline {
     stage("Lint Check") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-lint"
         }
       }
@@ -119,7 +119,7 @@ pipeline {
     stage("CPU Build") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-cpu:conda"
         }
       }
@@ -135,7 +135,7 @@ pipeline {
     stage("GPU Build") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-gpu:conda"
           args "-u root"
         }
@@ -171,7 +171,7 @@ pipeline {
     stage("C++ CPU") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-cpu:conda"
         }
       }
@@ -198,7 +198,7 @@ pipeline {
     stage("Tensorflow CPU") {
      agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-cpu:conda"
         }
       }
@@ -239,7 +239,7 @@ pipeline {
     stage("Torch CPU") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-cpu:conda"
         }
       }
@@ -316,7 +316,7 @@ pipeline {
     stage("MXNet CPU") {
       agent {
         docker {
-          label "linux-cpu-node"
+          label "linux-c52x-node"
           image "dgllib/dgl-ci-cpu:conda"
         }
       }
...
@@ -315,7 +315,10 @@ struct CSRMatrix {
       indptr(parr),
       indices(iarr),
       data(darr),
-      sorted(sorted_flag) {}
+      sorted(sorted_flag) {
+    CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
+      << "The indptr and indices arrays must have the same data type.";
+  }
   /*! \brief constructor from SparseMatrix object */
   explicit CSRMatrix(const SparseMatrix& spmat)
@@ -324,7 +327,10 @@ struct CSRMatrix {
       indptr(spmat.indices[0]),
       indices(spmat.indices[1]),
       data(spmat.indices[2]),
-      sorted(spmat.flags[0]) {}
+      sorted(spmat.flags[0]) {
+    CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
+      << "The indptr and indices arrays must have the same data type.";
+  }
   // Convert to a SparseMatrix object that can return to python.
   SparseMatrix ToSparseMatrix() const {
@@ -394,7 +400,10 @@ struct COOMatrix {
       col(carr),
       data(darr),
       row_sorted(rsorted),
-      col_sorted(csorted) {}
+      col_sorted(csorted) {
+    CHECK_EQ(row->dtype.bits, col->dtype.bits)
+      << "The row and col arrays must have the same data type.";
+  }
   /*! \brief constructor from SparseMatrix object */
   explicit COOMatrix(const SparseMatrix& spmat)
@@ -404,7 +413,10 @@ struct COOMatrix {
       col(spmat.indices[1]),
       data(spmat.indices[2]),
       row_sorted(spmat.flags[0]),
-      col_sorted(spmat.flags[1]) {}
+      col_sorted(spmat.flags[1]) {
+    CHECK_EQ(row->dtype.bits, col->dtype.bits)
+      << "The row and col arrays must have the same data type.";
+  }
   // Convert to a SparseMatrix object that can return to python.
   SparseMatrix ToSparseMatrix() const {
@@ -880,6 +892,29 @@ IdArray VecToIdArray(const std::vector<T>& vec,
   } \
   } while (0)
+
+/*
+ * Dispatch according to bits (either int32 or int64):
+ *
+ * ATEN_ID_BITS_SWITCH(bits, IdType, {
+ *   // Now IdType is the type corresponding to the data type in the array.
+ *   // For instance, one can do this for a CPU array:
+ *   IdType *data = static_cast<IdType *>(array->data);
+ * });
+ */
+#define ATEN_ID_BITS_SWITCH(bits, IdType, ...)                       \
+  do {                                                               \
+    CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64";  \
+    if ((bits) == 32) {                                              \
+      typedef int32_t IdType;                                        \
+      { __VA_ARGS__ }                                                \
+    } else if ((bits) == 64) {                                       \
+      typedef int64_t IdType;                                        \
+      { __VA_ARGS__ }                                                \
+    } else {                                                         \
+      LOG(FATAL) << "ID can only be int32 or int64";                 \
+    }                                                                \
+  } while (0)
 /*
  * Dispatch according to float type (either float32 or float64):
  *
...
@@ -669,10 +669,12 @@ HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArra
  *
  * TODO(minjie): remove the meta_graph argument
  *
+ * \tparam IdType Graph's index data type, can be int32_t or int64_t
  * \param meta_graph Metagraph of the inputs and result.
  * \param component_graphs Input graphs
  * \return One graph that unions all the components
  */
+template <class IdType>
 HeteroGraphPtr DisjointUnionHeteroGraph(
     GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
@@ -689,12 +691,14 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
  * TODO(minjie): remove the meta_graph argument; use vector<IdArray> for vertex_sizes
  * and edge_sizes.
  *
+ * \tparam IdType Graph's index data type, can be int32_t or int64_t
  * \param meta_graph Metagraph.
  * \param batched_graph Input graph.
  * \param vertex_sizes Number of vertices of each component.
  * \param edge_sizes Number of vertices of each component.
  * \return A list of graphs representing each disjoint components.
  */
+template <class IdType>
 std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
     GraphPtr meta_graph,
     HeteroGraphPtr batched_graph,
@@ -714,7 +718,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
  * This class can be used as arguments and return values of a C API.
  */
 struct HeteroPickleStates : public runtime::Object {
-  /*! \brief Metagraph. */
+  /*! \brief Metagraph(64bits ImmutableGraph) */
   GraphPtr metagraph;
   /*! \brief Number of nodes per type */
...
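These two templated routines are what graph batching and unbatching call into, so after this change a batched graph can keep the index width of its components instead of always working in 64-bit. A minimal Python-level sketch of that behaviour, assuming the DGL 0.4-era `dgl.batch_hetero`/`dgl.unbatch_hetero` API together with the `index_dtype` argument added in this PR:

```python
import dgl

# Two relation graphs built with 32-bit indices (index_dtype is the new flag).
g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', index_dtype='int32')
g2 = dgl.graph([(0, 2)], 'user', 'follows', index_dtype='int32')

# Batching performs the disjoint union; with the templated IdType it can
# stay in int32 rather than silently promoting to int64.
bg = dgl.batch_hetero([g1, g2])
parts = dgl.unbatch_hetero(bg)   # partition back into the two components
```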
@@ -49,6 +49,36 @@ class DGLIdIters {
   const dgl_id_t *begin_{nullptr}, *end_{nullptr};
 };
+
+/*!
+ * \brief int32 version of DGLIdIters
+ */
+class DGLIdIters32 {
+ public:
+  /*! \brief default constructor to create an empty range */
+  DGLIdIters32() {}
+  /*! \brief constructor with given begin and end */
+  DGLIdIters32(const int32_t *begin, const int32_t *end) {
+    this->begin_ = begin;
+    this->end_ = end;
+  }
+  const int32_t *begin() const {
+    return this->begin_;
+  }
+  const int32_t *end() const {
+    return this->end_;
+  }
+  int32_t operator[](int32_t i) const {
+    return *(this->begin_ + i);
+  }
+  size_t size() const {
+    return this->end_ - this->begin_;
+  }
+
+ private:
+  const int32_t *begin_{nullptr}, *end_{nullptr};
+};
+
 /* \brief structure used to represent a list of edges */
 typedef struct {
   /* \brief the two endpoints and the id of the edge */
...
@@ -17,6 +17,7 @@ namespace sched {
 /*!
  * \brief Generate degree bucketing schedule
+ * \tparam IdType Graph's index data type, can be int32_t or int64_t
  * \param msg_ids The edge id for each message
  * \param vids The destination vertex for each message
  * \param recv_ids The recv nodes (for checking zero degree nodes)
@@ -29,11 +30,13 @@ namespace sched {
  *   mids: message ids
  *   mid_section: number of messages in each bucket (used to split mids)
  */
+template <class IdType>
 std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
                                      const IdArray& recv_ids);
 /*!
  * \brief Generate degree bucketing schedule for group_apply edge
+ * \tparam IdType Graph's index data type, can be int32_t or int64_t
  * \param uids One end vertex of edge by which edges are grouped
  * \param vids The other end vertex of edge
  * \param eids Edge ids
@@ -49,6 +52,7 @@ std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids
  *   sections: number of edges in each degree bucket (used to partition
  *             new_uids, new_vids, and new_eids)
  */
+template <class IdType>
 std::vector<IdArray> GroupEdgeByNodeDegree(const IdArray& uids,
                                            const IdArray& vids, const IdArray& eids);
...
@@ -1073,7 +1073,7 @@ def sort_1d(input):
     """
     pass

-def arange(start, stop):
+def arange(start, stop, dtype):
     """Create a 1D range int64 tensor.

     Parameters
@@ -1082,6 +1082,8 @@ def arange(start, stop):
         The range start.
     stop : int
         The range stop.
+    dtype: str
+        The dtype of result tensor

     Returns
     -------
...
@@ -338,11 +338,11 @@ def sort_1d(input):
     idx = nd.cast(idx, dtype='int64')
     return val, idx

-def arange(start, stop):
+def arange(start, stop, dtype="int64"):
     if start >= stop:
-        return nd.array([], dtype=np.int64)
+        return nd.array([], dtype=data_type_dict()[dtype])
     else:
-        return nd.arange(start, stop, dtype=np.int64)
+        return nd.arange(start, stop, dtype=data_type_dict()[dtype])

 def rand_shuffle(arr):
     return mx.nd.random.shuffle(arr)
...
@@ -194,8 +194,8 @@ def nonzero_1d(input):
 def sort_1d(input):
     return np.sort(input), np.argsort(input)

-def arange(start, stop):
-    return np.arange(start, stop, dtype=np.int64)
+def arange(start, stop, dtype="int64"):
+    return np.arange(start, stop, dtype=getattr(np, dtype))

 def rand_shuffle(arr):
     copy = np.copy(arr)
...
@@ -154,7 +154,7 @@ def repeat(input, repeats, dim):
     return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)

 def gather_row(data, row_index):
-    return th.index_select(data, 0, row_index)
+    return th.index_select(data, 0, row_index.long())

 def slice_axis(data, axis, begin, end):
     return th.narrow(data, axis, begin, end - begin)
@@ -167,7 +167,7 @@ def narrow_row(x, start, stop):
     return x[start:stop]

 def scatter_row(data, row_index, value):
-    return data.index_copy(0, row_index, value)
+    return data.index_copy(0, row_index.long(), value)

 def scatter_row_inplace(data, row_index, value):
     data[row_index] = value
@@ -263,8 +263,8 @@ def nonzero_1d(input):
 def sort_1d(input):
     return th.sort(input)

-def arange(start, stop):
-    return th.arange(start, stop, dtype=th.int64)
+def arange(start, stop, dtype="int64"):
+    return th.arange(start, stop, dtype=data_type_dict()[dtype])

 def rand_shuffle(arr):
     idx = th.randperm(len(arr))
...
@@ -78,8 +78,10 @@ def sparse_matrix(data, index, shape, force_format=False):
     if fmt != 'coo':
         raise TypeError(
             'Tensorflow backend only supports COO format. But got %s.' % fmt)
-    spmat = tf.SparseTensor(indices=tf.transpose(
-        index[1], (1, 0)), values=data, dense_shape=shape)
+    # tf.SparseTensor only supports int64 indexing,
+    # therefore manually cast to int64 when the input is int32
+    spmat = tf.SparseTensor(indices=tf.cast(tf.transpose(
+        index[1], (1, 0)), tf.int64), values=data, dense_shape=shape)
     return spmat, None
@@ -372,9 +374,9 @@ def sort_1d(input):
     return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)

-def arange(start, stop):
+def arange(start, stop, dtype="int64"):
     with tf.device("/cpu:0"):
-        t = tf.range(start, stop, dtype=tf.int64)
+        t = tf.range(start, stop, dtype=data_type_dict()[dtype])
     return t
...
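Every backend's `arange` now takes a dtype string, with the concrete backends defaulting to `"int64"`. A small sketch of the intended behaviour, assuming the PyTorch backend is active (`dgl.backend` is the internal module these files implement):

```python
import torch as th
from dgl import backend as F

# The dtype string is mapped through the backend's data_type_dict().
a64 = F.arange(0, 5, "int64")   # values 0..4, stored as torch.int64
a32 = F.arange(0, 5, "int32")   # same values, stored as torch.int32

assert a64.dtype == th.int64
assert a32.dtype == th.int32
```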
@@ -209,8 +209,8 @@ def metapath_random_walk(hg, etypes, seeds, num_traces):
         raise ValueError('beginning and ending node type mismatch')
     if len(seeds) == 0:
         return []
-    etype_array = ndarray.array(np.asarray([hg.get_etype_id(et) for et in etypes], dtype='int64'))
-    seed_array = utils.toindex(seeds).todgltensor()
+    etype_array = ndarray.array(np.asarray([hg.get_etype_id(et) for et in etypes], dtype="int64"))
+    seed_array = utils.toindex(seeds, hg._idtype_str).todgltensor()
     traces = _CAPI_DGLMetapathRandomWalk(hg._graph, etype_array, seed_array, num_traces)
     return _split_traces(traces)
...
@@ -22,7 +22,7 @@ __all__ = [
 ]

 def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True,
-          restrict_format='any', **kwargs):
+          restrict_format='any', index_dtype="int64", **kwargs):
     """Create a graph with one type of nodes and edges.

     In the sparse matrix perspective, :func:`dgl.graph` creates a graph
@@ -135,22 +135,22 @@ def graph(data, ntype='_N', etype='_E', num_nodes=None, card=None, validate=True
         u, v = data
         return create_from_edges(
             u, v, ntype, etype, ntype, urange, vrange, validate,
-            restrict_format=restrict_format)
+            restrict_format=restrict_format, index_dtype=index_dtype)
     elif isinstance(data, list):
         return create_from_edge_list(
             data, ntype, etype, ntype, urange, vrange, validate,
-            restrict_format=restrict_format)
+            restrict_format=restrict_format, index_dtype=index_dtype)
     elif isinstance(data, sp.sparse.spmatrix):
         return create_from_scipy(
-            data, ntype, etype, ntype, restrict_format=restrict_format)
+            data, ntype, etype, ntype, restrict_format=restrict_format, index_dtype=index_dtype)
     elif isinstance(data, nx.Graph):
         return create_from_networkx(
-            data, ntype, etype, restrict_format=restrict_format, **kwargs)
+            data, ntype, etype, restrict_format=restrict_format, index_dtype=index_dtype, **kwargs)
     else:
         raise DGLError('Unsupported graph data type:', type(data))

 def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=None,
-              validate=True, restrict_format='any', **kwargs):
+              validate=True, restrict_format='any', index_dtype='int64', **kwargs):
     """Create a bipartite graph.

     The result graph is directed and edges must be from ``utype`` nodes
@@ -282,18 +282,19 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', num_nodes=None, card=Non
     if isinstance(data, tuple):
         u, v = data
         return create_from_edges(
-            u, v, utype, etype, vtype, urange, vrange, validate,
+            u, v, utype, etype, vtype, urange, vrange, validate, index_dtype=index_dtype,
             restrict_format=restrict_format)
     elif isinstance(data, list):
         return create_from_edge_list(
-            data, utype, etype, vtype, urange, vrange, validate,
+            data, utype, etype, vtype, urange, vrange, validate, index_dtype=index_dtype,
             restrict_format=restrict_format)
     elif isinstance(data, sp.sparse.spmatrix):
         return create_from_scipy(
-            data, utype, etype, vtype, restrict_format=restrict_format)
+            data, utype, etype, vtype, restrict_format=restrict_format, index_dtype=index_dtype)
     elif isinstance(data, nx.Graph):
-        return create_from_networkx_bipartite(
-            data, utype, etype, vtype, restrict_format=restrict_format, **kwargs)
+        return create_from_networkx_bipartite(data, utype, etype,
+                                              vtype, restrict_format=restrict_format,
+                                              index_dtype=index_dtype, **kwargs)
     else:
         raise DGLError('Unsupported graph data type:', type(data))
@@ -385,13 +386,18 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
         ntypes = list(sorted(ntype_set))
     else:
         ntypes = list(sorted(num_nodes_per_type.keys()))
-        num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes])
+        num_nodes_per_type = utils.toindex([num_nodes_per_type[ntype] for ntype in ntypes], "int64")
     ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
+    index_dtype = rel_graphs[0]._idtype_str
     for rgrh in rel_graphs:
+        if rgrh._idtype_str != index_dtype:
+            raise Exception("Expect relation graphs to be {}, but got {}".format(
+                index_dtype, rgrh._idtype_str))
         stype, etype, dtype = rgrh.canonical_etypes[0]
         meta_edges_src.append(ntype_dict[stype])
         meta_edges_dst.append(ntype_dict[dtype])
         etypes.append(etype)
+    # metagraph is DGLGraph, currently still using int64 as index dtype
     metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)

     # create graph index
@@ -404,7 +410,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
         retg._edge_frames[i].update(rgrh._edge_frames[0])
     return retg

-def heterograph(data_dict, num_nodes_dict=None):
+def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
     """Create a heterogeneous graph from a dictionary between edge types and edge lists.

     Parameters
@@ -484,15 +490,18 @@ def heterograph(data_dict, num_nodes_dict=None):
         elif srctype == dsttype:
             rel_graphs.append(graph(
                 data, srctype, etype,
-                num_nodes=num_nodes_dict[srctype], validate=False))
+                num_nodes=num_nodes_dict[srctype], validate=False, index_dtype=index_dtype))
         else:
             rel_graphs.append(bipartite(
                 data, srctype, etype, dsttype,
-                num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
+                num_nodes=(num_nodes_dict[srctype], num_nodes_dict[dsttype]),
+                validate=False, index_dtype=index_dtype))

     return hetero_from_relations(rel_graphs, num_nodes_dict)

-def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None):
+def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE,
+              metagraph=None):
     """Convert the given homogeneous graph to a heterogeneous graph.

     The input graph should have only one type of nodes and edges. Each node and edge
@@ -588,6 +597,7 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
                        ' type of nodes and edges.')

     num_ntypes = len(ntypes)
+    index_dtype = G._idtype_str

     ntype_ids = F.asnumpy(G.ndata[ntype_field])
     etype_ids = F.asnumpy(G.edata[etype_field])
@@ -641,15 +651,18 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
         if stid == dtid:
             rel_graph = graph(
                 (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
-                num_nodes=ntype_count[stid], validate=False)
+                num_nodes=ntype_count[stid], validate=False, index_dtype=index_dtype)
         else:
             rel_graph = bipartite(
-                (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
-                num_nodes=(ntype_count[stid], ntype_count[dtid]), validate=False)
+                (src_of_etype,
+                 dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
+                num_nodes=(ntype_count[stid], ntype_count[dtid]),
+                validate=False, index_dtype=index_dtype)
         rel_graphs.append(rel_graph)

-    hg = hetero_from_relations(
-        rel_graphs, {ntype: count for ntype, count in zip(ntypes, ntype_count)})
+    hg = hetero_from_relations(rel_graphs,
+                               {ntype: count for ntype, count in zip(
+                                   ntypes, ntype_count)})

     ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
     for ntid, ntype in enumerate(hg.ntypes):
@@ -722,7 +735,7 @@ def to_homo(G):
         num_nodes = G.number_of_nodes(ntype)
         total_num_nodes += num_nodes
         ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
-        nids.append(F.arange(0, num_nodes))
+        nids.append(F.arange(0, num_nodes, G._idtype_str))

     for etype_id, etype in enumerate(G.canonical_etypes):
         srctype, _, dsttype = etype
@@ -731,9 +744,10 @@ def to_homo(G):
         srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
         dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
         etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
-        eids.append(F.arange(0, num_edges))
+        eids.append(F.arange(0, num_edges, G._idtype_str))

-    retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes, validate=False)
+    retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
+                 validate=False, index_dtype=G._idtype_str)
     retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
     retg.ndata[NID] = F.cat(nids, 0)
     retg.edata[ETYPE] = F.cat(etype_ids, 0)
@@ -754,7 +768,7 @@ def to_homo(G):
 ############################################################

 def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True,
-                      restrict_format="any"):
+                      restrict_format="any", index_dtype='int64'):
     """Internal function to create a graph from incident nodes with types.

     utype could be equal to vtype
@@ -786,8 +800,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
     -------
     DGLHeteroGraph
     """
-    u = utils.toindex(u)
-    v = utils.toindex(v)
+    u = utils.toindex(u, index_dtype)
+    v = utils.toindex(v, index_dtype)
     if validate:
         if urange is not None and len(u) > 0 and \
@@ -817,7 +831,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
     return DGLHeteroGraph(hgidx, [utype, vtype], [etype])

 def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
-                          validate=True, restrict_format='any'):
+                          validate=True, restrict_format='any', index_dtype='int64'):
     """Internal function to create a heterograph from a list of edge tuples with types.

     utype could be equal to vtype
@@ -853,11 +867,11 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None,
         u, v = zip(*elist)
         u = list(u)
         v = list(v)
-    return create_from_edges(
-        u, v, utype, etype, vtype, urange, vrange, validate, restrict_format)
+    return create_from_edges(u, v, utype, etype, vtype, urange, vrange,
+                             validate, restrict_format, index_dtype=index_dtype)

 def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
-                      restrict_format='any'):
+                      restrict_format='any', index_dtype='int64'):
     """Internal function to create a heterograph from a scipy sparse matrix with types.

     Parameters
@@ -889,16 +903,16 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
     num_src, num_dst = spmat.shape
     num_ntypes = 1 if utype == vtype else 2
     if spmat.getformat() == 'coo':
-        row = utils.toindex(spmat.row)
-        col = utils.toindex(spmat.col)
+        row = utils.toindex(spmat.row.astype(index_dtype), index_dtype)
+        col = utils.toindex(spmat.col.astype(index_dtype), index_dtype)
         hgidx = heterograph_index.create_unitgraph_from_coo(
             num_ntypes, num_src, num_dst, row, col, restrict_format)
     else:
         spmat = spmat.tocsr()
-        indptr = utils.toindex(spmat.indptr)
-        indices = utils.toindex(spmat.indices)
+        indptr = utils.toindex(spmat.indptr.astype(index_dtype), index_dtype)
+        indices = utils.toindex(spmat.indices.astype(index_dtype), index_dtype)
         # TODO(minjie): with_edge_id is only reasonable for csr matrix. How to fix?
-        data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))))
+        data = utils.toindex(spmat.data if with_edge_id else list(range(len(indices))), index_dtype)
         hgidx = heterograph_index.create_unitgraph_from_csr(
             num_ntypes, num_src, num_dst, indptr, indices, data, restrict_format)
     if num_ntypes == 1:
@@ -911,7 +925,7 @@ def create_from_networkx(nx_graph,
                          edge_id_attr_name='id',
                          node_attrs=None,
                          edge_attrs=None,
-                         restrict_format='any'):
+                         restrict_format='any', index_dtype='int64'):
     """Create a heterograph that has only one set of nodes and edges.

     Parameters
@@ -949,8 +963,8 @@ def create_from_networkx(nx_graph,

     if has_edge_id:
         num_edges = nx_graph.number_of_edges()
-        src = np.zeros((num_edges,), dtype=np.int64)
-        dst = np.zeros((num_edges,), dtype=np.int64)
+        src = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
+        dst = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
         for u, v, attr in nx_graph.edges(data=True):
             eid = attr[edge_id_attr_name]
             src[eid] = u
@@ -961,11 +975,11 @@ def create_from_networkx(nx_graph,
         for e in nx_graph.edges:
             src.append(e[0])
             dst.append(e[1])
-    src = utils.toindex(src)
-    dst = utils.toindex(dst)
+    src = utils.toindex(src, index_dtype)
+    dst = utils.toindex(dst, index_dtype)
     num_nodes = nx_graph.number_of_nodes()
     g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes,
-                          validate=False, restrict_format=restrict_format)
+                          validate=False, restrict_format=restrict_format, index_dtype=index_dtype)

     # handle features
     # copy attributes
@@ -1017,7 +1031,7 @@ def create_from_networkx_bipartite(nx_graph,
                                    edge_id_attr_name='id',
                                    node_attrs=None,
                                    edge_attrs=None,
-                                   restrict_format='any'):
+                                   restrict_format='any', index_dtype='int64'):
     """Create a heterograph that has one set of source nodes, one set of
     destination nodes and one set of edges.
@@ -1065,8 +1079,8 @@ def create_from_networkx_bipartite(nx_graph,

     if has_edge_id:
         num_edges = nx_graph.number_of_edges()
-        src = np.zeros((num_edges,), dtype=np.int64)
-        dst = np.zeros((num_edges,), dtype=np.int64)
+        src = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
+        dst = np.zeros((num_edges,), dtype=getattr(np, index_dtype))
         for u, v, attr in nx_graph.edges(data=True):
             eid = attr[edge_id_attr_name]
             src[eid] = top_map[u]
@@ -1078,11 +1092,11 @@ def create_from_networkx_bipartite(nx_graph,
             if e[0] in top_map:
                 src.append(top_map[e[0]])
                 dst.append(bottom_map[e[1]])
-    src = utils.toindex(src)
-    dst = utils.toindex(dst)
-    g = create_from_edges(
-        src, dst, utype, etype, vtype,
-        len(top_nodes), len(bottom_nodes), validate=False, restrict_format=restrict_format)
+    src = utils.toindex(src, index_dtype)
+    dst = utils.toindex(dst, index_dtype)
+    g = create_from_edges(src, dst, utype, etype, vtype,
+                          len(top_nodes), len(bottom_nodes), validate=False,
+                          restrict_format=restrict_format, index_dtype=index_dtype)

     # TODO attributes
     assert node_attrs is None, 'Retrieval of node attributes are not supported yet.'
...
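The net effect of this file is that `index_dtype` flows from the public constructors (`dgl.graph`, `dgl.bipartite`, `dgl.heterograph`) down to every `create_from_*` helper, and `hetero_from_relations` insists that all relation graphs agree on it. A hedged usage sketch (argument names follow this diff; the exact error message is illustrative):

```python
import dgl

# Graphs whose index arrays are stored as int32 instead of the default int64.
g32 = dgl.graph([(0, 1), (1, 2)], index_dtype='int32')
hg = dgl.heterograph(
    {('user', 'follows', 'user'): [(0, 1), (1, 2)],
     ('user', 'plays', 'game'): [(0, 0), (1, 1)]},
    index_dtype='int32')

# Mixing index dtypes across relation graphs is now rejected.
try:
    dgl.hetero_from_relations([
        dgl.bipartite([(0, 0)], 'user', 'plays', 'game', index_dtype='int32'),
        dgl.bipartite([(0, 0)], 'game', 'played-by', 'user', index_dtype='int64'),
    ])
except Exception as err:
    print('mixed index dtypes rejected:', err)
```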
@@ -793,6 +793,29 @@ class DGLBaseGraph(object):
         v = utils.toindex(v)
         return self._graph.out_degrees(v).tousertensor()

+    @property
+    def idtype(self):
+        """Return the dtype of the graph index.
+
+        Returns
+        -------
+        backend dtype object
+            th.int32/th.int64 or tf.int32/tf.int64 etc.
+        """
+        return getattr(F, self._graph.dtype)
+
+    @property
+    def _idtype_str(self):
+        """The dtype of the graph index as a string.
+
+        Returns
+        -------
+        str
+            Either "int32" or "int64".
+        """
+        return self._graph.dtype
+
 def mutation(func):
     """A decorator to decorate functions that might change graph structure."""
...
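With the new properties the index width is observable from Python: `idtype` returns the active backend's dtype object and `_idtype_str` the plain string used internally. A quick sketch, assuming the PyTorch backend (homogeneous `DGLGraph` objects are still 64-bit at this point):

```python
import torch as th
import dgl

g = dgl.DGLGraph([(0, 1), (1, 2)])

print(g._idtype_str)           # 'int64' -- plain string, used internally
print(g.idtype)                # torch.int64 -- backend dtype object
assert g.idtype == th.int64
```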
@@ -864,6 +864,21 @@ class GraphIndex(ObjectBase):
         """
         return _CAPI_DGLGraphContext(self)

+    @property
+    def dtype(self):
+        """Return the index dtype
+
+        Returns
+        -------
+        str
+            The dtype of graph index
+        """
+        bits = self.nbits()
+        if bits == 32:
+            return "int32"
+        else:
+            return "int64"
+
     def copy_to(self, ctx):
         """Copy this immutable graph index to the given device context.
...
This diff is collapsed.
@@ -129,6 +129,7 @@ class HeteroGraphIndex(ObjectBase):
         _CAPI_DGLHeteroClear(self)
         self._cache.clear()

+    @property
     def dtype(self):
         """Return the data type of this graph index.
@@ -149,16 +150,6 @@ class HeteroGraphIndex(ObjectBase):
         """
         return _CAPI_DGLHeteroContext(self)

-    def nbits(self):
-        """Return the number of integer bits used in the storage (32 or 64).
-
-        Returns
-        -------
-        int
-            The number of bits.
-        """
-        return _CAPI_DGLHeteroNumBits(self)
-
     def bits_needed(self, etype):
         """Return the number of integer bits needed to represent the unitgraph graph.
@@ -298,7 +289,7 @@ class HeteroGraphIndex(ObjectBase):
             0-1 array indicating existence
         """
         vid_array = vids.todgltensor()
-        return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array))
+        return utils.toindex(_CAPI_DGLHeteroHasVertices(self, int(ntype), vid_array), self.dtype)

     def has_edge_between(self, etype, u, v):
         """Return true if the edge exists.
@@ -339,7 +330,7 @@ class HeteroGraphIndex(ObjectBase):
         u_array = u.todgltensor()
         v_array = v.todgltensor()
         return utils.toindex(_CAPI_DGLHeteroHasEdgesBetween(
-            self, int(etype), u_array, v_array))
+            self, int(etype), u_array, v_array), self.dtype)

     def predecessors(self, etype, v):
         """Return the predecessors of the node.
@@ -359,7 +350,7 @@ class HeteroGraphIndex(ObjectBase):
             Array of predecessors
         """
         return utils.toindex(_CAPI_DGLHeteroPredecessors(
-            self, int(etype), int(v)))
+            self, int(etype), int(v)), self.dtype)

     def successors(self, etype, v):
         """Return the successors of the node.
@@ -379,7 +370,7 @@ class HeteroGraphIndex(ObjectBase):
             Array of successors
         """
         return utils.toindex(_CAPI_DGLHeteroSuccessors(
-            self, int(etype), int(v)))
+            self, int(etype), int(v)), self.dtype)

     def edge_id(self, etype, u, v):
         """Return the id array of all edges between u and v.
@@ -399,7 +390,7 @@ class HeteroGraphIndex(ObjectBase):
             The edge id array.
         """
         return utils.toindex(_CAPI_DGLHeteroEdgeId(
-            self, int(etype), int(u), int(v)))
+            self, int(etype), int(u), int(v)), self.dtype)

     def edge_ids(self, etype, u, v):
         """Return a triplet of arrays that contains the edge IDs.
@@ -426,9 +417,9 @@ class HeteroGraphIndex(ObjectBase):
         v_array = v.todgltensor()
         edge_array = _CAPI_DGLHeteroEdgeIds(self, int(etype), u_array, v_array)

-        src = utils.toindex(edge_array(0))
-        dst = utils.toindex(edge_array(1))
-        eid = utils.toindex(edge_array(2))
+        src = utils.toindex(edge_array(0), self.dtype)
+        dst = utils.toindex(edge_array(1), self.dtype)
+        eid = utils.toindex(edge_array(2), self.dtype)

         return src, dst, eid
@@ -454,9 +445,9 @@ class HeteroGraphIndex(ObjectBase):
         eid_array = eid.todgltensor()
         edge_array = _CAPI_DGLHeteroFindEdges(self, int(etype), eid_array)

-        src = utils.toindex(edge_array(0))
-        dst = utils.toindex(edge_array(1))
-        eid = utils.toindex(edge_array(2))
+        src = utils.toindex(edge_array(0), self.dtype)
+        dst = utils.toindex(edge_array(1), self.dtype)
+        eid = utils.toindex(edge_array(2), self.dtype)

         return src, dst, eid
@@ -486,9 +477,9 @@ class HeteroGraphIndex(ObjectBase):
         else:
             v_array = v.todgltensor()
             edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), v_array)
-        src = utils.toindex(edge_array(0))
-        dst = utils.toindex(edge_array(1))
-        eid = utils.toindex(edge_array(2))
+        src = utils.toindex(edge_array(0), self.dtype)
+        dst = utils.toindex(edge_array(1), self.dtype)
+        eid = utils.toindex(edge_array(2), self.dtype)
         return src, dst, eid

     def out_edges(self, etype, v):
@@ -517,9 +508,9 @@ class HeteroGraphIndex(ObjectBase):
         else:
             v_array = v.todgltensor()
             edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), v_array)
-        src = utils.toindex(edge_array(0))
-        dst = utils.toindex(edge_array(1))
-        eid = utils.toindex(edge_array(2))
+        src = utils.toindex(edge_array(0), self.dtype)
+        dst = utils.toindex(edge_array(1), self.dtype)
+        eid = utils.toindex(edge_array(2), self.dtype)
         return src, dst, eid

     @utils.cached_member(cache='_cache', prefix='edges')
@@ -552,9 +543,9 @@ class HeteroGraphIndex(ObjectBase):
         src = edge_array(0)
         dst = edge_array(1)
         eid = edge_array(2)
-        src = utils.toindex(src)
-        dst = utils.toindex(dst)
-        eid = utils.toindex(eid)
+        src = utils.toindex(src, self.dtype)
+        dst = utils.toindex(dst, self.dtype)
+        eid = utils.toindex(eid, self.dtype)
         return src, dst, eid

     def in_degree(self, etype, v):
@@ -594,7 +585,7 @@ class HeteroGraphIndex(ObjectBase):
             The in degree array.
         """
         v_array = v.todgltensor()
-        return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array))
+        return utils.toindex(_CAPI_DGLHeteroInDegrees(self, int(etype), v_array), self.dtype)

     def out_degree(self, etype, v):
         """Return the out degree of the node.
@@ -633,7 +624,7 @@ class HeteroGraphIndex(ObjectBase):
             The out degree array.
         """
         v_array = v.todgltensor()
-        return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array))
+        return utils.toindex(_CAPI_DGLHeteroOutDegrees(self, int(etype), v_array), self.dtype)

     def adjacency_matrix(self, etype, transpose, ctx):
         """Return the adjacency matrix representation of this graph.
@@ -672,18 +663,20 @@ class HeteroGraphIndex(ObjectBase):
         ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
         nnz = self.number_of_edges(etype)
         if fmt == "csr":
-            indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
-            indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
-            shuffle = utils.toindex(rst(2))
+            indptr = F.copy_to(utils.toindex(rst(0), self.dtype).tousertensor(), ctx)
+            indices = F.copy_to(utils.toindex(rst(1), self.dtype).tousertensor(), ctx)
+            shuffle = utils.toindex(rst(2), self.dtype)
             dat = F.ones(nnz, dtype=F.float32, ctx=ctx)  # FIXME(minjie): data type
             spmat = F.sparse_matrix(dat, ('csr', indices, indptr), (nrows, ncols))[0]
             return spmat, shuffle
         elif fmt == "coo":
-            idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
+            idx = F.copy_to(utils.toindex(rst(0), self.dtype).tousertensor(), ctx)
             idx = F.reshape(idx, (2, nnz))
             dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)
-            adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (nrows, ncols))
-            shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
+            adj, shuffle_idx = F.sparse_matrix(
+                dat, ('coo', idx), (nrows, ncols))
+            shuffle_idx = utils.toindex(
+                shuffle_idx, self.dtype) if shuffle_idx is not None else None
             return adj, shuffle_idx
         else:
             raise Exception("unknown format")
@@ -732,12 +725,12 @@ class HeteroGraphIndex(ObjectBase):
         ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
         nnz = self.number_of_edges(etype)
         if fmt == "csr":
-            indptr = utils.toindex(rst(0)).tonumpy()
-            indices = utils.toindex(rst(1)).tonumpy()
+            indptr = utils.toindex(rst(0), self.dtype).tonumpy()
+            indices = utils.toindex(rst(1), self.dtype).tonumpy()
             data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
             return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
         elif fmt == 'coo':
-            idx = utils.toindex(rst(0)).tonumpy()
+            idx = utils.toindex(rst(0), self.dtype).tonumpy()
             row, col = np.reshape(idx, (2, nnz))
             data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row)
             return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols))
@@ -906,7 +899,8 @@ class HeteroGraphIndex(ObjectBase):
         order = csr(2)
         rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), False, "csr")
         rev_order = rev_csr(2)
-        return utils.toindex(order), utils.toindex(rev_order)
+        return utils.toindex(order, self.dtype), utils.toindex(rev_order, self.dtype)

 @register_object('graph.HeteroSubgraph')
 class HeteroSubgraphIndex(ObjectBase):
@@ -933,7 +927,7 @@ class HeteroSubgraphIndex(ObjectBase):
             Induced nodes
         """
         ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)
-        return [utils.toindex(v.data) for v in ret]
+        return [utils.toindex(v.data, self.graph.dtype) for v in ret]

     @property
     def induced_edges(self):
@@ -946,7 +940,7 @@ class HeteroSubgraphIndex(ObjectBase):
             Induced edges
         """
         ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)
-        return [utils.toindex(v.data) for v in ret]
+        return [utils.toindex(v.data, self.graph.dtype) for v in ret]

 #################################################################
 # Creators
...
@@ -136,11 +136,11 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
         The rows to write to output tensor.
     """
     if A_rows is None:
-        A_rows = nd.NULL
+        A_rows = nd.NULL[G.dtype]
     if B_rows is None:
-        B_rows = nd.NULL
+        B_rows = nd.NULL[G.dtype]
     if out_rows is None:
-        out_rows = nd.NULL
+        out_rows = nd.NULL[G.dtype]
     _CAPI_DGLKernelBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -200,11 +200,11 @@ def backward_lhs_binary_op_reduce(
         The rows written to output tensor.
     """
     if A_rows is None:
-        A_rows = nd.NULL
+        A_rows = nd.NULL[G.dtype]
     if B_rows is None:
-        B_rows = nd.NULL
+        B_rows = nd.NULL[G.dtype]
     if out_rows is None:
-        out_rows = nd.NULL
+        out_rows = nd.NULL[G.dtype]
     _CAPI_DGLKernelBackwardLhsBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -265,11 +265,11 @@ def backward_rhs_binary_op_reduce(
         The rows written to output tensor.
     """
     if A_rows is None:
-        A_rows = nd.NULL
+        A_rows = nd.NULL[G.dtype]
     if B_rows is None:
-        B_rows = nd.NULL
+        B_rows = nd.NULL[G.dtype]
     if out_rows is None:
-        out_rows = nd.NULL
+        out_rows = nd.NULL[G.dtype]
     _CAPI_DGLKernelBackwardRhsBinaryOpReduce(
         reducer, op, G,
         int(A_target), int(B_target),
@@ -364,9 +364,9 @@ def copy_reduce(reducer, G, target,
         The rows to write to output tensor.
     """
     if X_rows is None:
-        X_rows = nd.NULL
+        X_rows = nd.NULL[G.dtype]
     if out_rows is None:
-        out_rows = nd.NULL
+        out_rows = nd.NULL[G.dtype]
     _CAPI_DGLKernelCopyReduce(
         reducer, G, int(target),
         X, out, X_rows, out_rows)
@@ -406,9 +406,9 @@ def backward_copy_reduce(reducer, G, target,
         The rows written to output tensor.
     """
     if X_rows is None:
-        X_rows = nd.NULL
+        X_rows = nd.NULL[G.dtype]
     if out_rows is None:
-        out_rows = nd.NULL
+        out_rows = nd.NULL[G.dtype]
     _CAPI_DGLKernelBackwardCopyReduce(
         reducer, G, int(target),
         X, out, grad_out, grad_X,
...
@@ -177,4 +177,7 @@ _init_api("dgl.ndarray")

 # An array representing null (no value) that can be safely converted to
 # other backend tensors.
-NULL = array(_np.array([], dtype=_np.int64))
+NULL = {
+    "int64": array(_np.array([], dtype=_np.int64)),
+    "int32": array(_np.array([], dtype=_np.int32))
+}
@@ -70,7 +70,7 @@ def choice(a, size, replace=True, prob=None):  # pylint: disable=invalid-name
         population = a

     if prob is None:
-        prob = nd.NULL
+        prob = nd.NULL["int64"]
     else:
         prob = F.zerocopy_to_dgl_ndarray(prob)
...
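`dgl.ndarray.NULL` is now a dict keyed by the index dtype string, and callers such as the kernel wrappers above select the entry that matches the graph's dtype. A small sketch of that selection, under the assumption that the graph index exposes the `'int32'`/`'int64'` string shown earlier:

```python
import dgl.ndarray as nd

def null_for(index_dtype):
    """Return the empty placeholder array matching a graph's index dtype."""
    # NULL == {"int64": <empty int64 ndarray>, "int32": <empty int32 ndarray>}
    return nd.NULL[index_dtype]

empty64 = null_for("int64")   # what nd.NULL used to be
empty32 = null_for("int32")   # new 32-bit variant
```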
"""Module for degree bucketing schedulers.""" """Module for degree bucketing schedulers."""
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -118,11 +119,12 @@ def _process_node_buckets(buckets): ...@@ -118,11 +119,12 @@ def _process_node_buckets(buckets):
The zero-degree nodes The zero-degree nodes
""" """
# get back results # get back results
degs = utils.toindex(buckets(0)) dtype = buckets(0).dtype
v = utils.toindex(buckets(1)) degs = utils.toindex(buckets(0), dtype)
v = utils.toindex(buckets(1), dtype)
# XXX: convert directly from ndarary to python list? # XXX: convert directly from ndarary to python list?
v_section = buckets(2).asnumpy().tolist() v_section = buckets(2).asnumpy().tolist()
msg_ids = utils.toindex(buckets(3)) msg_ids = utils.toindex(buckets(3), dtype)
msg_section = buckets(4).asnumpy().tolist() msg_section = buckets(4).asnumpy().tolist()
# split buckets # split buckets
...@@ -131,8 +133,8 @@ def _process_node_buckets(buckets): ...@@ -131,8 +133,8 @@ def _process_node_buckets(buckets):
msg_ids = F.split(msg_ids, msg_section, 0) msg_ids = F.split(msg_ids, msg_section, 0)
# convert to utils.Index # convert to utils.Index
dsts = [utils.toindex(dst) for dst in dsts] dsts = [utils.toindex(dst, dtype) for dst in dsts]
msg_ids = [utils.toindex(msg_id) for msg_id in msg_ids] msg_ids = [utils.toindex(msg_id, dtype) for msg_id in msg_ids]
# handle zero deg # handle zero deg
degs = degs.tonumpy() degs = degs.tonumpy()
...@@ -266,17 +268,18 @@ def _process_edge_buckets(buckets): ...@@ -266,17 +268,18 @@ def _process_edge_buckets(buckets):
A list of edge id buckets A list of edge id buckets
""" """
# get back results # get back results
dtype = buckets(0).dtype
degs = buckets(0).asnumpy() degs = buckets(0).asnumpy()
uids = utils.toindex(buckets(1)) uids = utils.toindex(buckets(1), dtype)
vids = utils.toindex(buckets(2)) vids = utils.toindex(buckets(2), dtype)
eids = utils.toindex(buckets(3)) eids = utils.toindex(buckets(3), dtype)
# XXX: convert directly from ndarary to python list? # XXX: convert directly from ndarary to python list?
sections = buckets(4).asnumpy().tolist() sections = buckets(4).asnumpy().tolist()
# split buckets and convert to index # split buckets and convert to index
def split(to_split): def split(to_split):
res = F.split(to_split.tousertensor(), sections, 0) res = F.split(to_split.tousertensor(), sections, 0)
return map(utils.toindex, res) return map(partial(utils.toindex, dtype=dtype), res)
uids = split(uids) uids = split(uids)
vids = split(vids) vids = split(vids)
......