Unverified Commit 44089c8b authored by Minjie Wang, committed by GitHub

[Refactor][Graph] Merge DGLGraph and DGLHeteroGraph (#1862)



* Merge

* [Graph][CUDA] Graph on GPU and many refactoring (#1791)

* change edge_ids behavior and C++ impl

* fix unittests; remove utils.Index in edge_id

* pass mx and th tests

* pass tf test

* add aten::Scatter_

* Add nonzero; impl CSRGetDataAndIndices/CSRSliceMatrix

* CSRGetData and CSRGetDataAndIndices passed tests

* CSRSliceMatrix basic tests

* fix bug in empty slice

* CUDA CSRHasDuplicate

* has_node; has_edge_between

* predecessors, successors

* deprecate send/recv; fix send_and_recv

* deprecate send/recv; fix send_and_recv

* in_edges; out_edges; all_edges; apply_edges

* in deg/out deg

* subgraph/edge_subgraph

* adj

* in_subgraph/out_subgraph

* sample neighbors

* set/get_n/e_repr

* wip: working on refactoring all idtypes

* pass ndata/edata tests on gpu

* fix

* stash

* workaround nonzero issue

* stash

* nx conversion

* test_hetero_basics except update routines

* test_update_routines

* test_hetero_basics for pytorch

* more fixes

* WIP: flatten graph

* wip: flatten

* test_flatten

* test_to_device

* fix bug in to_homo

* fix bug in CSRSliceMatrix

* pass subgraph test

* fix send_and_recv

* fix filter

* test_heterograph

* passed all pytorch tests

* fix mx unittest

* fix pytorch test_nn

* fix all unittests for PyTorch

* passed all mxnet tests

* lint

* fix tf nn test

* pass all tf tests

* lint

* lint

* change deprecation

* try fix compile

* lint

* update METIS

* fix utest

* fix

* fix utests

* try debug

* revert

* small fix

* fix utests

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [kernel] Use heterograph index instead of unitgraph index (#1813)

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [Graph] Mutation for Heterograph (#1818)

* mutation add_nodes and add_edges

* Add support for remove_edges, remove_nodes, add_selfloop, remove_selfloop

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* upd

* upd

* upd

* fix

* [Transform] Mutable transform (#1833)

* add nodes

* All three

* Fix

* lint

* Add some test case

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* fix

* trigger

* Fix

* fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* [Graph] Migrate Batch & Readout module to heterograph (#1836)

* dgl.batch

* unbatch

* fix to device

* reduce readout; segment reduce

* change batch_num_nodes|edges to function

* reduce readout/ softmax

* broadcast

* topk

* fix

* fix tf and mx

* fix some ci

* fix batch but unbatch differently

* new check

* upd

* upd

* upd

* idtype behavior; code reorg

* idtype behavior; code reorg

* wip: test_basics

* pass test_basics

* WIP: from nx/ to nx

* missing files

* upd

* pass test_basics:test_nx_conversion

* Fix test

* Fix inplace update

* WIP: fixing tests

* upd

* pass test_transform cpu

* pass gpu test_transform

* pass test_batched_graph

* GPU graph auto cast to int32

* missing file

* stash

* WIP: rgcn-hetero

* Fix two datasets

* upd

* weird

* Fix capsule

* fuck you

* fuck matthias

* Fix dgmg

* fix bug in block degrees; pass rgcn-hetero

* rgcn

* gat and diffpool fix
also fix ppi and tu dataset

* Tree LSTM

* pointcloud

* rrn; wip: sgc

* resolve conflicts

* upd

* sgc and reddit dataset

* upd

* Fix deepwalk, gindt and gcn

* fix datasets and sign

* optimization

* optimization

* upd

* upd

* Fix GIN

* fix bug in add_nodes add_edges; tagcn

* adaptive sampling and gcmc

* upd

* upd

* fix geometric

* fix

* metapath2vec

* fix agnn

* fix pickling problem of block

* fix utests

* miss file

* linegraph

* upd

* upd

* upd

* graphsage

* stgcn_wave

* fix hgt

* on unittests

* Fix transformer

* Fix HAN

* passed pytorch unittests

* lint

* fix

* Fix cluster gcn

* cluster-gcn is ready

* on fixing block related codes

* 2nd order derivative

* Revert "2nd order derivative"

This reverts commit 523bf6c249bee61b51b1ad1babf42aad4167f206.

* passed torch utests again

* fix all mxnet unittests

* delete some useless tests

* pass all tf cpu tests

* disable

* disable distributed unittest

* fix

* fix

* lint

* fix

* fix

* fix script

* fix tutorial

* fix apply edges bug

* fix 2 basics

* fix tutorial
Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-42.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
parent 015acfd2
"""Checking and logging utilities."""
# pylint: disable=invalid-name
from __future__ import absolute_import, division
from ..base import DGLError, dgl_warning
from .. import backend as F
from .internal import to_dgl_context
def prepare_tensor(g, data, name):
"""Convert the data to ID tensor and check its ID type and context.
If the data is already in tensor type, raise error if its ID type
and context does not match the graph's.
Otherwise, convert it to tensor type of the graph's ID type and
ctx and return.
Parameters
----------
g : DGLHeteroGraph
Graph.
data : int, iterable of int, tensor
Data.
name : str
Name of the data.
Returns
-------
Tensor
Data in tensor object.
"""
ret = None
if F.is_tensor(data):
if F.dtype(data) != g.idtype or F.context(data) != g.device:
raise DGLError('Expect argument "{}" to have data type {} and device '
'context {}. But got {} and {}.'.format(
name, g.idtype, g.device, F.dtype(data), F.context(data)))
ret = data
else:
ret = F.copy_to(F.tensor(data, g.idtype), g.device)
if F.ndim(ret) != 1:
raise DGLError('Expect a 1-D tensor for argument "{}". But got {}.'.format(
name, ret))
return ret
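For illustration, here is a minimal standalone sketch of the conversion rule above, written with plain PyTorch; the function name ``prepare_tensor_sketch`` and the explicit ``idtype``/``device`` arguments are hypothetical stand-ins for the graph object and are not part of this change:

import torch

def prepare_tensor_sketch(data, idtype=torch.int64, device='cpu'):
    # Tensors must already match the expected dtype/device; everything else is converted.
    if torch.is_tensor(data):
        if data.dtype != idtype or data.device != torch.device(device):
            raise ValueError('expected dtype {} on {}, got {} on {}'.format(
                idtype, device, data.dtype, data.device))
        ret = data
    else:
        ret = torch.as_tensor(data, dtype=idtype, device=device)
    if ret.dim() != 1:
        raise ValueError('expected a 1-D tensor, got shape {}'.format(tuple(ret.shape)))
    return ret

print(prepare_tensor_sketch([0, 1, 2]))      # tensor([0, 1, 2])
# prepare_tensor_sketch(torch.zeros(3))      # would raise: float32 does not match int64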
def prepare_tensor_dict(g, data, name):
"""Convert a dictionary of data to a dictionary of ID tensors.
If calls ``prepare_tensor`` on each key-value pair.
Parameters
----------
g : DGLHeteroGraph
Graph.
data : dict[str, (int, iterable of int, tensor)]
Data dict.
name : str
Name of the data.
Returns
-------
dict[str, tensor]
"""
return {key : prepare_tensor(g, val, '{}["{}"]'.format(name, key))
for key, val in data.items()}
def check_all_same_idtype(glist, name):
"""Check all the graphs have the same idtype."""
if len(glist) == 0:
return
idtype = glist[0].idtype
for i, g in enumerate(glist):
if g.idtype != idtype:
raise DGLError('Expect {}[{}] to have {} type ID, but got {}.'.format(
name, i, idtype, g.idtype))
def check_all_same_device(glist, name):
"""Check all the graphs have the same device."""
if len(glist) == 0:
return
device = glist[0].device
for i, g in enumerate(glist):
if g.device != device:
raise DGLError('Expect {}[{}] to be on device {}, but got {}.'.format(
name, i, device, g.device))
def check_all_same_keys(dict_list, name):
"""Check all the dictionaries have the same set of keys."""
if len(dict_list) == 0:
return
keys = dict_list[0].keys()
for dct in dict_list:
if keys != dct.keys():
raise DGLError('Expect all {} to have the same set of keys, but got'
' {} and {}.'.format(name, keys, dct.keys()))
def check_all_have_keys(dict_list, keys, name):
"""Check the dictionaries all have the given keys."""
if len(dict_list) == 0:
return
keys = set(keys)
for dct in dict_list:
if not keys.issubset(dct.keys()):
raise DGLError('Expect all {} to include keys {}, but got {}.'.format(
name, keys, dct.keys()))
def check_all_same_schema(feat_dict_list, keys, name):
"""Check the features of the given keys all have the same schema.
Suggest calling ``check_all_have_keys`` first.
Parameters
----------
feat_dict_list : list[dict[str, Tensor]]
Feature dictionaries.
keys : list[str]
Keys
name : str
Name of this feature dict.
"""
if len(feat_dict_list) == 0:
return
for fdict in feat_dict_list:
for k in keys:
t1 = feat_dict_list[0][k]
t2 = fdict[k]
if F.dtype(t1) != F.dtype(t2) or F.shape(t1)[1:] != F.shape(t2)[1:]:
raise DGLError('Expect all features {}["{}"] to have the same data type'
' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
name, k, F.dtype(t1), F.shape(t1)[1:],
F.dtype(t2), F.shape(t2)[1:]))
def to_int32_graph_if_on_gpu(g):
"""Convert to int32 graph if the input graph is on GPU."""
# device_type 2 is an internal code for GPU
if to_dgl_context(g.device).device_type == 2 and g.idtype == F.int64:
dgl_warning('Automatically cast a GPU int64 graph to int32.\n'
' To suppress the warning, call DGLGraph.int() first\n'
' or specify the ``device`` argument when creating the graph.')
return g.int()
else:
return g
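A tiny sketch of the auto-cast policy above, kept independent of the graph classes; ``maybe_cast_to_int32`` and its string-valued arguments are illustrative only:

import warnings

def maybe_cast_to_int32(on_gpu, idtype):
    # Mirror of the rule above: GPU graphs with int64 IDs are cast down to int32.
    if on_gpu and idtype == 'int64':
        warnings.warn('Automatically cast a GPU int64 graph to int32; '
                      'call DGLGraph.int() to suppress the warning.')
        return 'int32'
    return idtype

assert maybe_cast_to_int32(True, 'int64') == 'int32'
assert maybe_cast_to_int32(False, 'int64') == 'int64'
assert maybe_cast_to_int32(True, 'int32') == 'int32'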
"""Data utilities."""
import scipy as sp
import networkx as nx
from ..base import DGLError
from .. import backend as F
def elist2tensor(elist, idtype):
"""Function to convert an edge list to edge tensors.
Parameters
----------
elist : iterable of int pairs
List of (src, dst) node ID pairs.
idtype : int32, int64, optional
Integer ID type. Must be int32 or int64.
Returns
-------
(Tensor, Tensor)
Edge tensors.
"""
if len(elist) == 0:
u, v = [], []
else:
u, v = zip(*elist)
u = list(u)
v = list(v)
return F.tensor(u, idtype), F.tensor(v, idtype)
def scipy2tensor(spmat, idtype):
"""Function to convert a scipy matrix to edge tensors.
Parameters
----------
spmat : scipy.sparse.spmatrix
SciPy sparse matrix.
idtype : int32, int64, optional
Integer ID type. Must be int32 or int64.
Returns
-------
(Tensor, Tensor)
Edge tensors.
"""
spmat = spmat.tocoo()
row = F.tensor(spmat.row, idtype)
col = F.tensor(spmat.col, idtype)
return row, col
def networkx2tensor(nx_graph, idtype, edge_id_attr_name='id'):
"""Function to convert a networkx graph to edge tensors.
Parameters
----------
nx_graph : nx.Graph
NetworkX graph.
idtype : int32, int64, optional
Integer ID type. Must be int32 or int64.
edge_id_attr_name : str, optional
Key name for edge ids in the NetworkX graph. If not found, we
will consider the graph not to have pre-specified edge ids. (Default: 'id')
Returns
-------
(Tensor, Tensor)
Edge tensors.
"""
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()
# Relabel nodes using consecutive integers
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')
# nx_graph.edges(data=True) returns src, dst, attr_dict
if nx_graph.number_of_edges() > 0:
has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
else:
has_edge_id = False
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = [0] * num_edges
dst = [0] * num_edges
for u, v, attr in nx_graph.edges(data=True):
eid = attr[edge_id_attr_name]
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
for e in nx_graph.edges:
src.append(e[0])
dst.append(e[1])
src = F.tensor(src, idtype)
dst = F.tensor(dst, idtype)
return src, dst
def graphdata2tensors(data, idtype=None, bipartite=False):
"""Function to convert various types of data to edge tensors and infer
the number of nodes.
Parameters
----------
data : graph data
Various kinds of graph data.
idtype : int32, int64, optional
Integer ID type. If None, try to infer it from the data; if that
fails, use int64.
bipartite : bool, optional
Whether to infer the number of nodes as for a bipartite graph,
where num_src and num_dst can be different.
Returns
-------
src : Tensor
Src nodes.
dst : Tensor
Dst nodes.
num_src : int
Number of source nodes
num_dst : int
Number of destination nodes.
"""
if idtype is None and not (isinstance(data, tuple) and F.is_tensor(data[0])):
# preferred default idtype is int64
# if data is tensor and idtype is None, infer the idtype from tensor
idtype = F.int64
if isinstance(data, tuple):
src, dst = F.tensor(data[0], idtype), F.tensor(data[1], idtype)
elif isinstance(data, list):
src, dst = elist2tensor(data, idtype)
elif isinstance(data, sp.sparse.spmatrix):
src, dst = scipy2tensor(data, idtype)
elif isinstance(data, nx.Graph):
if bipartite:
src, dst = networkxbipartite2tensors(data, idtype)
else:
src, dst = networkx2tensor(data, idtype)
else:
raise DGLError('Unsupported graph data type:', type(data))
infer_from_raw = infer_num_nodes(data, bipartite=bipartite)
if infer_from_raw is None:
num_src, num_dst = infer_num_nodes((src, dst), bipartite=bipartite)
else:
num_src, num_dst = infer_from_raw
return src, dst, num_src, num_dst
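As a rough sketch of the dispatch above, the following standalone helper (hypothetical name ``to_src_dst``) shows how the accepted input kinds all normalize to a (src, dst) pair, using NumPy/SciPy/NetworkX directly rather than the DGL backend:

import numpy as np
import scipy.sparse as sp
import networkx as nx

def to_src_dst(data):
    # Normalize the supported input kinds to a (src, dst) array pair.
    if isinstance(data, tuple):                  # (u, v) array/tensor pair
        return np.asarray(data[0]), np.asarray(data[1])
    if isinstance(data, list):                   # [(u, v), ...] edge list
        u, v = zip(*data) if data else ([], [])
        return np.asarray(u), np.asarray(v)
    if sp.issparse(data):                        # SciPy sparse adjacency matrix
        coo = data.tocoo()
        return coo.row, coo.col
    if isinstance(data, nx.Graph):               # NetworkX graph (made directed first)
        edges = list(data.to_directed().edges())
        u, v = zip(*edges) if edges else ([], [])
        return np.asarray(u), np.asarray(v)
    raise TypeError('unsupported graph data type: {}'.format(type(data)))

print(to_src_dst([(0, 1), (1, 2)]))                                # (array([0, 1]), array([1, 2]))
print(to_src_dst(sp.coo_matrix(([1], ([0], [2])), shape=(3, 3))))  # single nonzero at (0, 2)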
def networkxbipartite2tensors(nx_graph, idtype, edge_id_attr_name='id'):
"""Function to convert a networkx bipartite to edge tensors.
Parameters
----------
nx_graph : nx.Graph
NetworkX graph. It must follow the bipartite graph convention of networkx.
Each node has an attribute ``bipartite`` with values 0 and 1 indicating
which set it belongs to. Only edges from node set 0 to node set 1 are
added to the returned graph.
idtype : int32, int64, optional
Integer ID type. Must be int32 or int64.
edge_id_attr_name : str, optional
Key name for edge ids in the NetworkX graph. If not found, we
will consider the graph not to have pre-specified edge ids. (Default: 'id')
Returns
-------
(Tensor, Tensor)
Edge tensors.
"""
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()
top_nodes = {n for n, d in nx_graph.nodes(data=True) if d['bipartite'] == 0}
bottom_nodes = set(nx_graph) - top_nodes
top_nodes = sorted(top_nodes)
bottom_nodes = sorted(bottom_nodes)
top_map = {n : i for i, n in enumerate(top_nodes)}
bottom_map = {n : i for i, n in enumerate(bottom_nodes)}
if nx_graph.number_of_edges() > 0:
has_edge_id = edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
else:
has_edge_id = False
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = [0] * num_edges
dst = [0] * num_edges
for u, v, attr in nx_graph.edges(data=True):
eid = attr[edge_id_attr_name]
src[eid] = top_map[u]
dst[eid] = bottom_map[v]
else:
src = []
dst = []
for e in nx_graph.edges:
if e[0] in top_map:
src.append(top_map[e[0]])
dst.append(bottom_map[e[1]])
src = F.tensor(src, dtype=idtype)
dst = F.tensor(dst, dtype=idtype)
return src, dst
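A small example of the NetworkX bipartite convention assumed by this function; the node names and the manual relabeling below are illustrative, not the DGL code path itself:

import networkx as nx

g = nx.Graph()
g.add_nodes_from(['u0', 'u1'], bipartite=0)            # source side
g.add_nodes_from(['i0', 'i1', 'i2'], bipartite=1)      # destination side
g.add_edges_from([('u0', 'i0'), ('u1', 'i2')])

top = sorted(n for n, d in g.nodes(data=True) if d['bipartite'] == 0)
bottom = sorted(n for n, d in g.nodes(data=True) if d['bipartite'] == 1)
top_map = {n: i for i, n in enumerate(top)}
bottom_map = {n: i for i, n in enumerate(bottom)}

src, dst = [], []
for a, b in g.edges():
    u, v = (a, b) if a in top_map else (b, a)          # keep only side-0 -> side-1 direction
    src.append(top_map[u])
    dst.append(bottom_map[v])
print(src, dst)                                        # [0, 1] [0, 2]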
def infer_num_nodes(data, bipartite=False):
"""Function for inferring the number of nodes.
Parameters
----------
data : graph data
Supported types are:
* Tensor pair (u, v)
* SciPy matrix
* NetworkX graph
bipartite : bool, optional
Whether to infer the number of nodes as for a bipartite graph,
where num_src and num_dst can be different.
Returns
-------
num_src : int
Number of source nodes.
num_dst : int
Number of destination nodes.
or
None
If the inference failed.
"""
if isinstance(data, tuple) and len(data) == 2 and F.is_tensor(data[0]):
u, v = data
nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0
ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0
elif isinstance(data, sp.sparse.spmatrix):
nsrc, ndst = data.shape[0], data.shape[1]
elif isinstance(data, nx.Graph):
if data.number_of_nodes() == 0:
nsrc = ndst = 0
elif not bipartite:
nsrc = ndst = data.number_of_nodes()
else:
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc
else:
return None
if not bipartite:
nsrc = ndst = max(nsrc, ndst)
return nsrc, ndst
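A worked example of the inference rule for the non-bipartite case (take max ID + 1 on each side, then use the larger value for both sides):

src, dst = [0, 5, 2], [1, 3, 3]
nsrc = max(src) + 1 if src else 0      # 6
ndst = max(dst) + 1 if dst else 0      # 4
nsrc = ndst = max(nsrc, ndst)          # non-bipartite: both sides share one node set
print(nsrc, ndst)                      # 6 6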
"""Utility module."""
"""Internal utilities."""
from __future__ import absolute_import, division
from collections.abc import Mapping, Iterable
......@@ -6,9 +6,9 @@ from collections import defaultdict
from functools import wraps
import numpy as np
from .base import DGLError, dgl_warning
from . import backend as F
from . import ndarray as nd
from ..base import DGLError, dgl_warning
from .. import backend as F
from .. import ndarray as nd
class InconsistentDtypeException(DGLError):
......@@ -684,3 +684,27 @@ class FlattenedDict(object):
k = self._group_keys[i]
j = idx - self._group_offsets[i]
return k, self._groups[k][j]
def compensate(ids, origin_ids):
"""computing the compensate set of ids from origin_ids
Note: ids should be a subset of origin_ids.
Any of ids and origin_ids can be non-consecutive,
and origin_ids should be sorted.
Example:
>>> ids = th.Tensor([0, 2, 4])
>>> origin_ids = th.Tensor([0, 1, 2, 4, 5])
>>> compensate(ids, origin_ids)
th.Tensor([1, 5])
"""
# Trick: the smallest ID (eid_0 or nid_0) can be 0 and would read as "absent"
# in the nonzero test below, so force position 0 of the mask to be nonzero.
mask = F.scatter_row(origin_ids,
F.copy_to(F.tensor(0, dtype=F.int64),
F.context(origin_ids)),
F.copy_to(F.tensor(1, dtype=F.dtype(origin_ids)),
F.context(origin_ids)))
mask = F.scatter_row(mask,
ids,
F.full_1d(len(ids), 0, F.dtype(ids), F.context(ids)))
return F.tensor(F.nonzero_1d(mask), dtype=F.dtype(ids))
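A NumPy sketch of the mask idea behind ``compensate`` (mark every original ID in a dense mask, clear the given subset, read off the complement); it mirrors the docstring example but is not a line-for-line port of the backend calls:

import numpy as np

ids = np.array([0, 2, 4])
origin_ids = np.array([0, 1, 2, 4, 5])           # sorted; ids is a subset

mask = np.zeros(origin_ids.max() + 1, dtype=np.int64)
mask[origin_ids] = 1                              # mark every ID that exists
mask[ids] = 0                                     # clear the IDs we already have
print(np.nonzero(mask)[0])                        # [1 5], the complement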
......@@ -266,7 +266,7 @@ class HeteroNodeView(object):
ntype = None
elif isinstance(key, tuple):
nodes, ntype = key
elif isinstance(key, str):
elif key is None or isinstance(key, str):
nodes = ALL
ntype = key
else:
......@@ -277,8 +277,10 @@ class HeteroNodeView(object):
def __call__(self, ntype=None):
"""Return the nodes."""
return F.arange(0, self._graph.number_of_nodes(ntype),
dtype=self._graph._idtype_str)
ntid = self._typeid_getter(ntype)
return F.copy_to(F.arange(0, self._graph._graph.number_of_nodes(ntid),
dtype=self._graph.idtype),
self._graph.device)
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
......@@ -337,6 +339,12 @@ class HeteroNodeDataView(MutableMapping):
'can not be iterated.'
return iter(self._graph._node_frames[self._ntid])
def keys(self):
return self._graph._node_frames[self._ntid].keys()
def values(self):
return self._graph._node_frames[self._ntid].values()
def __repr__(self):
if isinstance(self._ntype, list):
ret = {}
......@@ -366,6 +374,9 @@ class HeteroEdgeView(object):
raise DGLError('Currently only full slice ":" is supported')
edges = ALL
etype = None
elif key is None:
edges = ALL
etype = None
elif isinstance(key, tuple):
if len(key) == 3:
edges = ALL
......@@ -373,7 +384,7 @@ class HeteroEdgeView(object):
else:
edges = key
etype = None
elif isinstance(key, (str, tuple)):
elif isinstance(key, str):
edges = ALL
etype = key
else:
......@@ -444,6 +455,12 @@ class HeteroEdgeDataView(MutableMapping):
'can not be iterated.'
return iter(self._graph._edge_frames[self._etid])
def keys(self):
return self._graph._edge_frames[self._etid].keys()
def values(self):
return self._graph._edge_frames[self._etid].values()
def __repr__(self):
if isinstance(self._etype, list):
ret = {}
......
......@@ -63,6 +63,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
<< static_cast<int>(bits) << ".";
if (arr->dtype.bits == bits)
return arr;
if (arr.NumElements() == 0)
return NewIdArray(arr->shape[0], arr->ctx, bits);
IdArray ret;
ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
......@@ -76,20 +78,20 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_SAME_CONTEXT(lhs, rhs);
CHECK_SAME_DTYPE(lhs, rhs);
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "HStack", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::HStack<XPU, IdType>(lhs, rhs);
});
});
return ret;
}
IdArray NonZero(BoolArray bool_arr) {
IdArray ret;
ATEN_XPU_SWITCH(bool_arr->ctx.device_type, XPU, "NonZero", {
ATEN_ID_TYPE_SWITCH(bool_arr->dtype, IdType, {
ret = impl::NonZero<XPU, IdType>(bool_arr);
});
CHECK_EQ(lhs->shape[0], rhs->shape[0]);
auto device = runtime::DeviceAPI::Get(lhs->ctx);
const auto& ctx = lhs->ctx;
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
const int64_t len = lhs->shape[0];
ret = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
device->CopyDataFromTo(lhs.Ptr<IdType>(), 0,
ret.Ptr<IdType>(), 0,
len * sizeof(IdType),
ctx, ctx, lhs->dtype, nullptr);
device->CopyDataFromTo(rhs.Ptr<IdType>(), 0,
ret.Ptr<IdType>(), len * sizeof(IdType),
len * sizeof(IdType),
ctx, ctx, lhs->dtype, nullptr);
});
return ret;
}
......@@ -161,6 +163,22 @@ NDArray Scatter(NDArray array, IdArray indices) {
return ret;
}
void Scatter_(IdArray index, NDArray value, NDArray out) {
CHECK_SAME_DTYPE(value, out);
CHECK_SAME_CONTEXT(index, value);
CHECK_SAME_CONTEXT(index, out);
CHECK_EQ(value->shape[0], index->shape[0]);
if (index->shape[0] == 0)
return;
ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
impl::Scatter_<XPU, DType, IdType>(index, value, out);
});
});
});
}
NDArray Repeat(NDArray array, IdArray repeats) {
NDArray ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Repeat", {
......@@ -259,6 +277,16 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return ret;
}
IdArray NonZero(NDArray array) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
ATEN_ID_TYPE_SWITCH(array->dtype, DType, {
ret = impl::NonZero<XPU, DType>(array);
});
});
return ret;
}
std::string ToDebugString(NDArray array) {
std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
......@@ -300,7 +328,7 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
bool CSRHasDuplicate(CSRMatrix csr) {
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRHasDuplicate", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRHasDuplicate", {
ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
});
return ret;
......@@ -353,23 +381,13 @@ bool CSRIsSorted(CSRMatrix csr) {
return ret;
}
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, row, col);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
NDArray ret;
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
});
return ret;
......@@ -382,7 +400,7 @@ std::vector<NDArray> CSRGetDataAndIndices(
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
std::vector<NDArray> ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetDataAndIndices", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetDataAndIndices", {
ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
});
return ret;
......@@ -443,7 +461,7 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceMatrix", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRSliceMatrix", {
ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
});
return ret;
......@@ -583,14 +601,6 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(COOMatrix coo, int64_t row)
return ret;
}
NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
ret = impl::COOGetData<XPU, IdType>(coo, row, col);
});
return ret;
}
std::vector<NDArray> COOGetDataAndIndices(
COOMatrix coo, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
......@@ -600,6 +610,14 @@ std::vector<NDArray> COOGetDataAndIndices(
return ret;
}
NDArray COOGetData(COOMatrix coo, NDArray rows, NDArray cols) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
ret = impl::COOGetData<XPU, IdType>(coo, rows, cols);
});
return ret;
}
COOMatrix COOTranspose(COOMatrix coo) {
return COOMatrix(coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data);
}
......@@ -971,6 +989,16 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
#endif // _WIN32
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DLDataType dtype = array->dtype;
dtype.code = kDLInt;
*rv = array.CreateView(shape, dtype, 0);
});
} // namespace aten
} // namespace dgl
......
......@@ -37,9 +37,6 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray array);
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index);
......@@ -52,6 +49,9 @@ IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);
template <DLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats);
......@@ -70,6 +70,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
template <DLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero);
template <DLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
......@@ -96,9 +99,6 @@ runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
......@@ -180,13 +180,13 @@ template <DLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/array_nonzero.cc
* \brief Array nonzero CPU implementation
*/
#include <dgl/array.h>
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
std::vector<int64_t> ret;
const IdType* data = array.Ptr<IdType>();
for (int64_t i = 0; i < array->shape[0]; ++i)
if (data[i] != 0)
ret.push_back(i);
return NDArray::FromVector(ret, array->ctx);
}
template IdArray NonZero<kDLCPU, int32_t>(IdArray);
template IdArray NonZero<kDLCPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten
} // namespace dgl
......@@ -161,26 +161,6 @@ IdArray UnaryElewise(IdArray lhs) {
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// HStack /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L, DLContext{kDLCPU, 0}, arr1->dtype.bits);
const IdType* arr1_data = static_cast<IdType*>(arr1->data);
const IdType* arr2_data = static_cast<IdType*>(arr2->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < L; ++i) {
ret_data[i] = arr1_data[i];
ret_data[i + L] = arr2_data[i];
}
return ret;
}
template IdArray HStack<kDLCPU, int32_t>(IdArray arr1, IdArray arr2);
template IdArray HStack<kDLCPU, int64_t>(IdArray arr1, IdArray arr2);
///////////////////////////// Full /////////////////////////////
template <DLDeviceType XPU, typename IdType>
......
......@@ -33,6 +33,26 @@ template NDArray Scatter<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
const DType* val = value.Ptr<DType>();
DType* outd = out.Ptr<DType>();
#pragma omp parallel for
for (int64_t i = 0; i < len; ++i)
outd[idx[i]] = val[i];
}
template void Scatter_<kDLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
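For reference, the in-place scatter semantics implemented above reduce to the following NumPy one-liner (illustrative only):

import numpy as np

index = np.array([3, 0, 2])
value = np.array([30., 0., 20.])
out = np.full(5, -1.)
out[index] = value                    # Scatter_: out[index[i]] = value[i], in place
print(out)                            # [ 0. -1. 20. 30. -1.]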
......@@ -152,22 +152,60 @@ COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t);
///////////////////////////// COOGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
std::vector<IdType> ret_vec;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const IdType* data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
for (IdType i = 0; i < coo.row->shape[0]; ++i) {
if (coo_row_data[i] == row && coo_col_data[i] == col)
ret_vec.push_back(data ? data[i] : i);
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col Id array:" << rows << " " << cols;
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const IdType* row_data = rows.Ptr<IdType>();
const IdType* col_data = cols.Ptr<IdType>();
const IdType* coo_row = coo.row.Ptr<IdType>();
const IdType* coo_col = coo.col.Ptr<IdType>();
const IdType* data = COOHasData(coo) ? coo.data.Ptr<IdType>() : nullptr;
const int64_t nnz = coo.row->shape[0];
const int64_t retlen = std::max(rowlen, collen);
IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
IdType* ret_data = ret.Ptr<IdType>();
// TODO(minjie): We might need to consider sorting the COO beforehand especially
// when the number of (row, col) pairs is large. Need more benchmarks to justify
// the choice.
if (coo.row_sorted) {
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
auto it = std::lower_bound(coo_row, coo_row + nnz, row_id);
for (; it < coo_row + nnz && *it == row_id; ++it) {
const auto idx = it - coo_row;
if (coo_col[idx] == col_id) {
ret_data[p] = data? data[idx] : idx;
break;
}
}
}
} else {
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
for (int64_t idx = 0; idx < nnz; ++idx) {
if (coo_row[idx] == row_id && coo_col[idx] == col_id) {
ret_data[p] = data? data[idx] : idx;
break;
}
}
}
}
return NDArray::FromVector(ret_vec);
return ret;
}
template NDArray COOGetData<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template NDArray COOGetData<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template IdArray COOGetData<kDLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDLCPU, int64_t>(COOMatrix, IdArray, IdArray);
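A Python sketch of the lookup semantics implemented above: rows/cols broadcast when one of them has length 1, the first match wins, and -1 marks a missing entry. The helper name ``coo_get_data`` and the toy matrix are illustrative:

coo_row = [0, 0, 1, 2, 2]
coo_col = [1, 2, 0, 2, 3]

def coo_get_data(rows, cols):
    row_stride = 0 if len(rows) == 1 and len(cols) != 1 else 1
    col_stride = 0 if len(cols) == 1 and len(rows) != 1 else 1
    out = []
    for p in range(max(len(rows), len(cols))):
        r, c = rows[p * row_stride], cols[p * col_stride]
        hit = -1
        for idx in range(len(coo_row)):
            if coo_row[idx] == r and coo_col[idx] == c:
                hit = idx            # first match wins; coo.data[idx] would be used if present
                break
        out.append(hit)
    return out

print(coo_get_data([0, 1, 2], [2, 0, 0]))   # [1, 2, -1]
print(coo_get_data([2], [2, 3]))            # row broadcasts: [3, 4]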
///////////////////////////// COOGetDataAndIndices /////////////////////////////
......
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/spmat_op_impl.cc
* \brief Sparse matrix operator CPU implementation
* \file array/cpu/spmat_op_impl_csr.cc
* \brief CSR matrix operator CPU implementation
*/
#include <dgl/array.h>
#include <vector>
......@@ -164,30 +164,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
}
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
std::vector<IdType> ret_vec;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
if (csr.sorted) {
CollectDataFromSorted<XPU, IdType>(indices_data, data,
indptr_data[row], indptr_data[row + 1],
col, &ret_vec);
} else {
for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) {
if (indices_data[i] == col) {
ret_vec.push_back(data? data[i] : i);
}
}
}
return NDArray::FromVector(ret_vec, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
IdArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
......@@ -203,26 +180,45 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
std::vector<IdType> ret_vec;
const int64_t retlen = std::max(rowlen, collen);
IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
IdType* ret_data = ret.Ptr<IdType>();
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
if (csr.sorted) {
CollectDataFromSorted<XPU, IdType>(indices_data, data,
indptr_data[row_id], indptr_data[row_id + 1],
col_id, &ret_vec);
} else {
for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
if (indices_data[i] == col_id) {
ret_vec.push_back(data? data[i] : i);
// NOTE: In most cases, the input csr is already sorted. If not, we might need to
// consider sorting it especially when the number of (row, col) pairs is large.
// Need more benchmarks to justify the choice.
if (csr.sorted) {
// use binary search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
const IdType *start_ptr = indices_data + indptr_data[row_id];
const IdType *end_ptr = indices_data + indptr_data[row_id + 1];
auto it = std::lower_bound(start_ptr, end_ptr, col_id);
if (it != end_ptr && *it == col_id) {
const IdType idx = it - indices_data;
ret_data[p] = data? data[idx] : idx;
}
}
} else {
// linear search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1]; ++idx) {
if (indices_data[idx] == col_id) {
ret_data[p] = data? data[idx] : idx;
break;
}
}
}
}
return NDArray::FromVector(ret_vec, csr.data->ctx);
return ret;
}
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
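A Python sketch of the sorted-row branch above: binary-search the column inside the row's slice of ``indices`` and map the hit back to an edge ID, or -1 when absent (toy data, illustrative only):

import bisect

indptr  = [0, 2, 3, 5]            # toy 3-row CSR
indices = [1, 2, 0, 2, 3]         # column IDs, sorted within each row

def csr_get_data(row, col):
    start, end = indptr[row], indptr[row + 1]
    i = bisect.bisect_left(indices, col, start, end)
    return i if i < end and indices[i] == col else -1

print(csr_get_data(0, 2))   # 1
print(csr_get_data(2, 1))   # -1, no such entry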
......@@ -491,8 +487,6 @@ template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0];
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/array_nonzero.cu
* \brief Array nonzero GPU implementation
*/
#include <thrust/iterator/counting_iterator.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/device_vector.h>
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <typename IdType>
struct IsNonZero {
__device__ bool operator() (const IdType val) {
return val != 0;
}
};
template <DLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int64_t len = array->shape[0];
IdArray ret = NewIdArray(len, array->ctx, 64);
thrust::device_ptr<IdType> in_data(array.Ptr<IdType>());
thrust::device_ptr<int64_t> out_data(ret.Ptr<int64_t>());
// TODO(minjie): should take control of the memory allocator.
// See PyTorch's implementation here:
// https://github.com/pytorch/pytorch/blob/1f7557d173c8e9066ed9542ada8f4a09314a7e17/
// aten/src/THC/generic/THCTensorMath.cu#L104
auto startiter = thrust::make_counting_iterator<int64_t>(0);
auto enditer = startiter + len;
auto indices_end = thrust::copy_if(thrust::cuda::par.on(thr_entry->stream),
startiter,
enditer,
in_data,
out_data,
IsNonZero<IdType>());
const int64_t num_nonzeros = indices_end - out_data;
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
}
template IdArray NonZero<kDLGPU, int32_t>(IdArray);
template IdArray NonZero<kDLGPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/cuda/array_scatter.cu
* \brief Array scatter GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <typename DType, typename IdType>
__global__ void _ScatterKernel(const IdType* index, const DType* value,
int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[index[tx]] = value[tx];
tx += stride_x;
}
}
template <DLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
const DType* val = value.Ptr<DType>();
DType* outd = out.Ptr<DType>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
_ScatterKernel<<<nb, nt, 0, thr_entry->stream>>>(idx, val, len, outd);
}
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/spmat_op_impl.cu
* \brief Sparse matrix operator CPU implementation
* \file array/cuda/spmat_op_impl_csr.cu
* \brief CSR matrix operator CUDA implementation
*/
#include <dgl/array.h>
#include <vector>
......@@ -19,31 +19,31 @@ namespace impl {
/*!
* \brief Search adjacency list linearly for each (row, col) pair and
* write the matched position in the indices array to the output.
*
* write the data under the matched position in the indices array to the output.
*
* If there is no match, -1 is written.
* If there are multiple matches, only the first match is written.
* If the given data array is null, write the matched position to the output.
*/
template <typename IdType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices,
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride,
int64_t length, IdType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
int rpos = tx, cpos = tx;
while (tx < length) {
out[tx] = -1;
int rpos = tx * row_stride, cpos = tx * col_stride;
IdType v = -1;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
out[tx] = i;
v = (data)? data[i] : i;
break;
}
}
rpos += row_stride;
cpos += col_stride;
out[tx] = v;
tx += stride_x;
}
}
......@@ -59,9 +59,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
rows = rows.CopyTo(ctx);
cols = cols.CopyTo(ctx);
IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
_LinearSearchKernel<<<1, 1, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
1, 1, 1,
out.Ptr<IdType>());
......@@ -85,9 +86,10 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
_LinearSearchKernel<<<nb, nt, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, rstlen,
rst.Ptr<IdType>());
......@@ -97,6 +99,52 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
template NDArray CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate /////////////////////////////
/*!
* \brief Check whether each row does not have any duplicate entries.
* Assume the CSR is sorted.
*/
template <typename IdType>
__global__ void _SegmentHasNoDuplicate(
const IdType* indptr, const IdType* indices,
int64_t num_rows, int8_t* flags) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < num_rows) {
bool f = true;
for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
f = (indices[i - 1] != indices[i]);
}
flags[tx] = static_cast<int8_t>(f);
tx += stride_x;
}
}
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
if (!csr.sorted)
csr = CSRSort(csr);
const auto& ctx = csr.indptr->ctx;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of num_rows bytes. It wastes a little bit of memory but
// should be fine.
int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
const int nt = cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt;
_SegmentHasNoDuplicate<<<nb, nt, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
csr.num_rows, flags);
bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
device->FreeWorkspace(ctx, flags);
return !ret;
}
template bool CSRHasDuplicate<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLGPU, int64_t>(CSRMatrix csr);
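A Python sketch of the check above: after sorting each row, a duplicate exists iff two adjacent ``indices`` entries within the same row are equal (toy data, illustrative only):

indptr  = [0, 2, 5]
indices = [1, 3, 0, 2, 2]         # row 1 contains column 2 twice

def csr_has_duplicate(indptr, indices):
    for r in range(len(indptr) - 1):
        row = sorted(indices[indptr[r]:indptr[r + 1]])
        if any(a == b for a, b in zip(row, row[1:])):
            return True
    return False

print(csr_has_duplicate(indptr, indices))   # True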
///////////////////////////// CSRGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
......@@ -211,14 +259,13 @@ __global__ void _SegmentCopyKernel(
const IdType* out_indptr, DType* out_data) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
int rpos = tx;
while (tx < length) {
int rpos = tx * row_stride;
const IdType r = row[rpos];
DType* out_buf = out_data + out_indptr[tx];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
*(out_buf++) = data? data[i] : i;
}
rpos += row_stride;
tx += stride_x;
}
}
......@@ -252,6 +299,249 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray CSRGetData(CSRMatrix csr, NDArray row, NDArray col) {
const int64_t rowlen = row->shape[0];
const int64_t collen = col->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const int64_t rstlen = std::max(rowlen, collen);
IdArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
if (rstlen == 0)
return rst;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
// TODO(minjie): use binary search for sorted csr
_LinearSearchKernel<<<nb, nt, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, rstlen,
rst.Ptr<IdType>());
return rst;
}
template NDArray CSRGetData<kDLGPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLGPU, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
/*!
* \brief Generate a 0-1 mask for each index that hits the provided (row, col)
* index.
*
* Examples:
* Given a CSR matrix (with duplicate entries) as follows:
* [[0, 1, 2, 0, 0],
* [1, 0, 0, 0, 0],
* [0, 0, 1, 1, 0],
* [0, 0, 0, 0, 0]]
* Given rows: [0, 1], cols: [0, 2, 3]
* The result mask is: [0, 1, 1, 1, 0, 0]
*/
template <typename IdType>
__global__ void _SegmentMaskKernel(
const IdType* indptr, const IdType* indices,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride,
int64_t length, IdType* mask) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
int rpos = tx * row_stride, cpos = tx * col_stride;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
mask[i] = 1;
}
}
tx += stride_x;
}
}
/*!
* \brief Search for the insertion positions for needle in the hay.
*
* The hay is a list of sorted elements and the result is the insertion position
* of each needle so that the insertion still gives sorted order.
*
* It essentially performs a binary search to find the lower bound for each
* needle element.
*/
template <typename IdType>
__global__ void _SortedSearchKernel(
const IdType* hay, int64_t hay_size,
const IdType* needles, int64_t num_needles,
IdType* pos) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < num_needles) {
const IdType ele = needles[tx];
// binary search
IdType lo = 0, hi = hay_size - 1;
while (lo < hi) {
IdType mid = (lo + hi) >> 1;
if (hay[mid] <= ele) {
lo = mid + 1;
} else {
hi = mid;
}
}
pos[tx] = (hay[hi] == ele)? hi : hi - 1;
tx += stride_x;
}
}
template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
const auto len = std::max(rowlen, collen);
if (len == 0)
return {NullArray(), NullArray(), NullArray()};
const auto& ctx = row->ctx;
const auto nbits = row->dtype.bits;
const int64_t nnz = csr.indices->shape[0];
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// Generate a 0-1 mask for matched (row, col) positions.
IdArray mask = Full(0, nnz, nbits, ctx);
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
_SegmentMaskKernel<<<nb, nt, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, len,
mask.Ptr<IdType>());
IdArray idx = AsNumBits(NonZero(mask), nbits);
if (idx->shape[0] == 0)
// No data. Return three empty arrays.
return {idx, idx, idx};
// Search for row index
IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
const int nt2 = cuda::FindNumThreads(idx->shape[0]);
const int nb2 = (idx->shape[0] + nt2 - 1) / nt2;
_SortedSearchKernel<<<nb2, nt2, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.num_rows,
idx.Ptr<IdType>(), idx->shape[0],
ret_row.Ptr<IdType>());
// Column & data can be obtained by index select.
IdArray ret_col = IndexSelect(csr.indices, idx);
IdArray ret_data = CSRHasData(csr)? IndexSelect(csr.data, idx) : idx;
return {ret_row, ret_col, ret_data};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
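A NumPy sketch of the pipeline above, using the example matrix from the _SegmentMaskKernel comment: build a 0/1 mask over the nnz positions that match the query, take its nonzeros, and recover each hit's row by binary-searching ``indptr`` (the role of _SortedSearchKernel); the mask loop here is a plain re-statement, not the kernel itself:

import numpy as np

indptr  = np.array([0, 3, 4, 6, 6])      # CSR of the 4x5 example matrix above
indices = np.array([1, 2, 2, 0, 2, 3])   # row 0 stores column 2 twice (duplicate entry)
rows, cols = np.array([0, 1]), np.array([0, 2, 3])

mask = np.zeros(len(indices), dtype=np.int64)
for r in rows:                                   # mark nnz positions hit by the query
    for i in range(indptr[r], indptr[r + 1]):
        if indices[i] in cols:
            mask[i] = 1
print(mask)                                      # [0 1 1 1 0 0]

idx = np.nonzero(mask)[0]                        # positions of the hits
ret_row = np.searchsorted(indptr, idx, side='right') - 1   # recover the row of each hit
ret_col = indices[idx]
print(ret_row, ret_col)                          # [0 0 1] [2 2 0]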
///////////////////////////// CSRSliceMatrix /////////////////////////////
/*!
* \brief Generate a 0-1 mask for each index whose column is in the provided set.
* It also counts the number of masked values per row.
*/
template <typename IdType>
__global__ void _SegmentMaskColKernel(
const IdType* indptr, const IdType* indices, int64_t num_rows,
const IdType* col, int64_t col_len,
IdType* mask, IdType* count) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
// TODO(minjie): consider putting the col array in shared memory.
while (tx < num_rows) {
IdType cnt = 0;
for (IdType i = indptr[tx]; i < indptr[tx + 1]; ++i) {
const IdType cur_c = indices[i];
for (int64_t j = 0; j < col_len; ++j) {
if (cur_c == col[j]) {
mask[i] = 1;
++cnt;
break;
}
}
}
count[tx] = cnt;
tx += stride_x;
}
}
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const auto& ctx = rows->ctx;
const auto& dtype = rows->dtype;
const auto nbits = dtype.bits;
const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0];
if (new_nrows == 0 || new_ncols == 0)
return CSRMatrix(new_nrows, new_ncols,
Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx));
// First slice rows
csr = CSRSliceRows(csr, rows);
if (csr.indices->shape[0] == 0)
return CSRMatrix(new_nrows, new_ncols,
Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx));
// Generate a 0-1 mask for matched (row, col) positions.
IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
// A count for how many masked values per row.
IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
const int nt = cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt;
_SegmentMaskColKernel<<<nb, nt, 0, thr_entry->stream>>>(
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows,
cols.Ptr<IdType>(), cols->shape[0],
mask.Ptr<IdType>(), count.Ptr<IdType>());
IdArray idx = AsNumBits(NonZero(mask), nbits);
if (idx->shape[0] == 0)
return CSRMatrix(new_nrows, new_ncols,
Full(0, new_nrows + 1, nbits, ctx),
NullArray(dtype, ctx), NullArray(dtype, ctx));
// Indptr needs to be adjusted according to the new nnz per row.
IdArray ret_indptr = CumSum(count, true);
// Column & data can be obtained by index select.
IdArray ret_col = IndexSelect(csr.indices, idx);
IdArray ret_data = CSRHasData(csr)? IndexSelect(csr.data, idx) : idx;
// Relabel column
IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
ret_col = IndexSelect(col_hash, ret_col);
return CSRMatrix(new_nrows, new_ncols, ret_indptr,
ret_col, ret_data);
}
template CSRMatrix CSRSliceMatrix<kDLGPU, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLGPU, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
} // namespace impl
} // namespace aten
} // namespace dgl
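A NumPy sketch of the slicing steps above: slice the requested rows, mask entries whose column survives, rebuild ``indptr`` from per-row counts with a cumulative sum, and relabel the surviving columns through a scatter-built lookup table (the Scatter_ step); all names and toy data are illustrative:

import numpy as np

indptr  = np.array([0, 2, 3, 5])           # toy 3-row CSR
indices = np.array([1, 3, 0, 2, 3])
rows    = np.array([0, 2])                 # rows to keep
cols    = np.array([3, 2])                 # columns to keep; their new IDs are 0, 1

# Row slice: concatenate the kept rows' nnz segments.
seg = [np.arange(indptr[r], indptr[r + 1]) for r in rows]
sub_indices = indices[np.concatenate(seg)]
sub_indptr  = np.concatenate([[0], np.cumsum([len(s) for s in seg])])

# Mask surviving columns and count them per kept row.
keep = np.isin(sub_indices, cols)
new_counts = np.array([keep[s:e].sum() for s, e in zip(sub_indptr[:-1], sub_indptr[1:])])
new_indptr = np.concatenate([[0], np.cumsum(new_counts)])

# Relabel columns via a scatter-built lookup table: col_hash[old_id] = new_id.
col_hash = np.full(indices.max() + 1, -1)
col_hash[cols] = np.arange(len(cols))
new_indices = col_hash[sub_indices[keep]]
print(new_indptr, new_indices)             # [0 1 3] [0 1 0]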
......@@ -70,16 +70,15 @@ void SpMM(const std::string& op, const std::string& reduce,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux,
SparseFormat format) {
std::vector<NDArray> out_aux) {
// TODO(zihao): format tuning
format = SparseFormat::kCSR;
SparseFormat format = graph->SelectFormat(0, csc_code);
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH(out->dtype, DType, "Feature data", {
if (format == SparseFormat::kCSR) {
if (format == SparseFormat::kCSC) {
SpMMCsr<XPU, IdType, DType>(
op, reduce, bcast, graph->GetCSCMatrix(0),
ufeat, efeat, out, out_aux);
......@@ -88,7 +87,7 @@ void SpMM(const std::string& op, const std::string& reduce,
op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux);
} else {
LOG(FATAL) << "SpMM only supports CSR and COO foramts";
LOG(FATAL) << "SpMM only supports CSC and COO foramts";
}
});
});
......@@ -102,11 +101,10 @@ void SDDMM(const std::string& op,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target,
SparseFormat format) {
int rhs_target) {
// TODO(zihao): format tuning
format = SparseFormat::kCOO;
const auto& bcast = CalcBcastOff(op, lhs, rhs);
SparseFormat format = graph->SelectFormat(0, coo_code);
const auto &bcast = CalcBcastOff(op, lhs, rhs);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
......@@ -150,7 +148,7 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
{0, 1, 2, 2, 2},
{U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE}, SparseFormat::kAny);
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
......@@ -173,7 +171,7 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
{lhs_target, rhs_target, 1},
{lhs, rhs, out},
{"U_data", "E_data", "V_data"});
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target, SparseFormat::kAny);
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
} // namespace aten
......
......@@ -18,48 +18,52 @@ HeteroGraphPtr CreateHeteroGraph(
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format) {
IdArray row, IdArray col, dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO(
num_vtypes, num_src, num_dst, row, col, restrict_format);
num_vtypes, num_src, num_dst, row, col, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format) {
auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, restrict_format);
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format);
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format) {
auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, restrict_format);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format);
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format) {
auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, restrict_format);
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
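The creation functions above now take a dgl_format_code_t bitmask describing the set of allowed sparse formats instead of a single restricting SparseFormat. The exact encoding lives in DGL's sparse-format headers; the self-contained sketch below only illustrates the general idea of packing a set of formats into one code and unpacking it again. Every name and the bit layout here are assumptions for illustration, not DGL's definitions.

#include <cstdint>
#include <vector>

// Illustrative stand-ins for SparseFormat and dgl_format_code_t.
enum class SketchFormat { kCOO, kCSR, kCSC };
using sketch_format_code_t = uint8_t;

// Pack a set of formats into one bitmask code (cf. SparseFormatsToCode).
sketch_format_code_t SketchFormatsToCode(const std::vector<SketchFormat>& fmts) {
  sketch_format_code_t code = 0;
  for (SketchFormat f : fmts) {
    if (f == SketchFormat::kCOO) code |= 0x1;
    if (f == SketchFormat::kCSR) code |= 0x2;
    if (f == SketchFormat::kCSC) code |= 0x4;
  }
  return code;
}

// Unpack a code back into the formats it allows (cf. CodeToSparseFormats).
std::vector<SketchFormat> SketchCodeToFormats(sketch_format_code_t code) {
  std::vector<SketchFormat> fmts;
  if (code & 0x1) fmts.push_back(SketchFormat::kCOO);
  if (code & 0x2) fmts.push_back(SketchFormat::kCSR);
  if (code & 0x4) fmts.push_back(SketchFormat::kCSC);
  return fmts;
}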
......
@@ -44,6 +44,10 @@ HeteroSubgraph EdgeSubgraphPreserveNodes(
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) {
// TODO(minjie): In general, all relabeling should be separated with subgraph
// operations.
CHECK(hg->Context().device_type != kDLGPU)
<< "Edge subgraph with relabeling does not support GPU.";
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type.";
HeteroSubgraph ret;
@@ -129,7 +133,6 @@ InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
dgl_type_t *dsttypes = static_cast<dgl_type_t *>(etype_array.dst->data);
dgl_type_t *etypes = static_cast<dgl_type_t *>(etype_array.id->data);
for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i];
@@ -181,7 +184,6 @@ HeteroGraph::HeteroGraph(
num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
else
num_verts_per_type_ = num_nodes_per_type;
HeteroGraphSanityCheck(meta_graph, rel_graphs);
relation_graphs_ = CastToUnitGraphs(rel_graphs);
}
@@ -260,11 +262,11 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
hgindex->num_verts_per_type_));
}
HeteroGraphPtr HeteroGraph::GetGraphInFormat(SparseFormat restrict_format) const {
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
auto relgraph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
format_rels[etype] = relgraph->GetGraphInFormat(restrict_format);
format_rels[etype] = relgraph->GetGraphInFormat(formats);
}
return HeteroGraphPtr(new HeteroGraph(
meta_graph_, format_rels, NumVerticesPerType()));
@@ -284,9 +286,8 @@ template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0;
std::vector<IdType> result_src, result_dst;
std::vector<dgl_type_t> induced_srctype, induced_etype, induced_dsttype;
std::vector<IdType> induced_srcid, induced_eid, induced_dstid;
std::vector<dgl_type_t> induced_srctype, induced_dsttype;
std::vector<IdType> induced_srcid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this
@@ -337,6 +338,13 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
}
}
// TODO(minjie): Using concat operations causes memory fragmentation.
// Need to optimize this in the future.
std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
src_arrs.reserve(etypes.size());
dst_arrs.reserve(etypes.size());
eid_arrs.reserve(etypes.size());
induced_etypes.reserve(etypes.size());
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
@@ -346,36 +354,34 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
EdgeArray edges = Edges(etype);
size_t num_edges = NumEdges(etype);
const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
const IdType* edges_eid_data = static_cast<const IdType*>(edges.id->data);
// TODO(gq) Use concat?
for (size_t i = 0; i < num_edges; ++i) {
result_src.push_back(edges_src_data[i] + srctype_offset);
result_dst.push_back(edges_dst_data[i] + dsttype_offset);
induced_etype.push_back(etype);
induced_eid.push_back(edges_eid_data[i]);
}
src_arrs.push_back(edges.src + srctype_offset);
dst_arrs.push_back(edges.dst + dsttype_offset);
eid_arrs.push_back(edges.id);
induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context()));
}
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
aten::Concat(src_arrs),
aten::Concat(dst_arrs));
// Sanity check
CHECK_EQ(gptr->Context(), Context());
CHECK_EQ(gptr->NumBits(), NumBits());
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->induced_srctype = aten::VecToIdArray(induced_srctype);
result->induced_srctype_set = aten::VecToIdArray(srctype_set);
result->induced_srcid = aten::VecToIdArray(induced_srcid);
result->induced_etype = aten::VecToIdArray(induced_etype);
result->induced_etype_set = aten::VecToIdArray(etypes);
result->induced_eid = aten::VecToIdArray(induced_eid);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype);
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set);
result->induced_dstid = aten::VecToIdArray(induced_dstid);
result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
result->induced_etype = aten::Concat(induced_etypes);
result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
result->induced_eid = aten::Concat(eid_arrs);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context());
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context());
result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
return FlattenedHeteroGraphPtr(result);
}
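FlattenImpl now shifts each relation's endpoint arrays as whole tensors and concatenates the per-relation arrays once, instead of appending edges one at a time. Below is a self-contained sketch of that pattern with plain std::vector standing in for IdArray; the names are illustrative only, not DGL's API.

#include <cstdint>
#include <vector>

// Shift one relation's endpoint IDs by its node-type offset in the flattened
// graph, mirroring `edges.src + srctype_offset` above.
std::vector<int64_t> ShiftIds(std::vector<int64_t> ids, int64_t offset) {
  for (int64_t& v : ids) v += offset;
  return ids;
}

// Concatenate all per-relation arrays in a single pass, mirroring aten::Concat.
std::vector<int64_t> ConcatIds(const std::vector<std::vector<int64_t>>& parts) {
  size_t total = 0;
  for (const auto& p : parts) total += p.size();
  std::vector<int64_t> out;
  out.reserve(total);
  for (const auto& p : parts) out.insert(out.end(), p.begin(), p.end());
  return out;
}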
......
@@ -102,8 +102,12 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->EdgeId(0, src, dst);
}
EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIds(0, src, dst);
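// EdgeIds is split in two: EdgeIdsAll returns every matching edge ID for each
// (src, dst) pair, while EdgeIdsOne returns a single edge ID per pair.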
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);
}
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
@@ -183,18 +187,16 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->GetCSRMatrix(0);
}
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_format);
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);
}
std::string GetRestrictFormat() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return std::string("");
dgl_format_code_t GetAllowedFormats() const override {
return GetRelationGraph(0)->GetAllowedFormats();
}
dgl_format_code_t GetFormatInUse() const override {
LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
return 0;
dgl_format_code_t GetCreatedFormats() const override {
return GetRelationGraph(0)->GetCreatedFormats();
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
@@ -202,7 +204,7 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;
HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
......
@@ -10,6 +10,7 @@
#include "../c_api_common.h"
#include "./heterograph.h"
#include "unit_graph.h"
using namespace dgl::runtime;
@@ -26,8 +27,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
SparseFormat restrict_format = ParseSparseFormat(args[5]);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, restrict_format);
List<Value> formats = args[5];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, code);
*rv = HeteroGraphRef(hgptr);
});
@@ -39,9 +46,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
IdArray indptr = args[3];
IdArray indices = args[4];
IdArray edge_ids = args[5];
SparseFormat restrict_format = ParseSparseFormat(args[6]);
auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids,
restrict_format);
List<Value> formats = args[6];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
*rv = HeteroGraphRef(hgptr);
});
@@ -242,16 +254,26 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
*rv = hg->EdgeId(etype, src, dst);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIds")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
IdArray dst = args[3];
const auto& ret = hg->EdgeIds(etype, src, dst);
const auto& ret = hg->EdgeIdsAll(etype, src, dst);
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
IdArray dst = args[3];
*rv = hg->EdgeIdsOne(etype, src, dst);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
@@ -537,37 +559,52 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRestrictFormat")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
*rv = hg->GetRelationGraph(etype)->GetRestrictFormat();
List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetCreatedFormats();
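// Report the formats that have already been materialized; the next entry
// point below reports the formats the graph is allowed to materialize.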
for (auto format : CodeToSparseFormats(code)) {
format_list.push_back(
Value(MakeValue(ToStringSparseFormat(format))));
}
*rv = format_list;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatInUse")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAllowedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
*rv = hg->GetRelationGraph(etype)->GetFormatInUse();
List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
for (auto format : CodeToSparseFormats(code)) {
format_list.push_back(
Value(MakeValue(ToStringSparseFormat(format))));
}
*rv = format_list;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRequestFormat")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const std::string sparse_format = args[1];
dgl_type_t etype = args[2];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
bg->GetFormat(ParseSparseFormat(sparse_format));
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
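// Materialize every allowed format on every relation graph up front, so later
// kernels can use any of them without building it lazily at call time.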
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
for (auto format : CodeToSparseFormats(code))
bg->GetFormat(format);
}
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const std::string restrict_format = args[1];
auto hgptr = hg->GetGraphInFormat(ParseSparseFormat(restrict_format));
List<Value> formats = args[1];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
auto hgptr = hg->GetGraphInFormat(
SparseFormatsToCode(formats_vec));
*rv = HeteroGraphRef(hgptr);
});
......
@@ -200,7 +200,7 @@ IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "invalid vertex: " << src;
CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
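// CSRGetAllData returns the data (edge IDs) of every matching entry between
// src and dst, so parallel edges in a multigraph are all reported.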
return aten::CSRGetData(adj_, src, dst);
return aten::CSRGetAllData(adj_, src, dst);
}
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
......