Unverified Commit f4608c22 authored by Minjie Wang, committed by GitHub

[CUDA][Kernel] A bunch of int64 kernels for COO and CSR (#1883)

* COO sort

* COOToCSR

* CSR2COO

* CSRSort; CSRTranspose

* pass all CSR tests

* lint

* remove int32 conversion

* fix tensorflow nn tests

* turn on CI

* fix

* address comments
parent 5b515cf6
......@@ -232,9 +232,7 @@ pipeline {
stages {
stage("Unit test") {
steps {
// TODO(minjie): tmp disabled
//unit_test_linux("tensorflow", "gpu")
sh "echo skipped"
unit_test_linux("tensorflow", "gpu")
}
}
}
......
......@@ -90,16 +90,19 @@ IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs);
IdArray Div(IdArray lhs, IdArray rhs);
IdArray Mod(IdArray lhs, IdArray rhs);
IdArray Add(IdArray lhs, int64_t rhs);
IdArray Sub(IdArray lhs, int64_t rhs);
IdArray Mul(IdArray lhs, int64_t rhs);
IdArray Div(IdArray lhs, int64_t rhs);
IdArray Mod(IdArray lhs, int64_t rhs);
IdArray Add(int64_t lhs, IdArray rhs);
IdArray Sub(int64_t lhs, IdArray rhs);
IdArray Mul(int64_t lhs, IdArray rhs);
IdArray Div(int64_t lhs, IdArray rhs);
IdArray Mod(int64_t lhs, IdArray rhs);
IdArray Neg(IdArray array);
......@@ -304,6 +307,17 @@ IdArray CumSum(IdArray array, bool prepend_zero = false);
*/
IdArray NonZero(NDArray array);
/*!
* \brief Sort the ID vector in ascending order.
*
 * It performs both a sort and an argsort (i.e., it also returns the sorted index).
 * The sorted index is always int64.
*
* \param array Input array.
 * \return A pair of arrays: the sorted values, and the index of each sorted element in the original array.
*/
std::pair<IdArray, IdArray> Sort(IdArray array);
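For reference, a minimal sketch of how this entry point might be called from C++ (the input array ids and its contents are assumptions for illustration; only the Sort signature above comes from this header, and the namespace is assumed to be dgl::aten like the other array ops):

    // ids is an existing IdArray holding {3, 1, 2}.
    std::pair<IdArray, IdArray> res = dgl::aten::Sort(ids);
    IdArray vals = res.first;   // sorted values {1, 2, 3}, same dtype/context as ids
    IdArray idx  = res.second;  // original positions {1, 2, 0}, always int64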
/*!
* \brief Return a string that prints out some debug information.
*/
......
......@@ -603,14 +603,18 @@ dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator + (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& array);
dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1,
......
......@@ -148,10 +148,7 @@ def graph(data,
g = create_from_edges(u, v, ntype, etype, ntype, urange, vrange,
validate, formats=formats)
if device is None:
return utils.to_int32_graph_if_on_gpu(g)
else:
return g.to(device)
return g.to(device)
def bipartite(data,
utype='_U', etype='_E', vtype='_V',
......@@ -300,10 +297,7 @@ def bipartite(data,
u, v, utype, etype, vtype, urange, vrange, validate,
formats=formats)
if device is None:
return utils.to_int32_graph_if_on_gpu(g)
else:
return g.to(device)
return g.to(device)
def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
"""Create a heterograph from graphs representing connections of each relation.
......
......@@ -4450,7 +4450,7 @@ class DGLHeteroGraph(object):
device(type='cpu')
"""
if device is None or self.device == device:
return utils.to_int32_graph_if_on_gpu(self)
return self
ret = copy.copy(self)
......@@ -4481,8 +4481,6 @@ class DGLHeteroGraph(object):
for k, num in self._batch_num_edges.items()}
ret._batch_num_edges = new_bne
ret = utils.to_int32_graph_if_on_gpu(ret)
return ret
def cpu(self):
......
......@@ -11,7 +11,7 @@ __all__ = ['edge_softmax']
def edge_softmax_real(graph, score, eids=ALL):
"""Edge Softmax function"""
if not is_all(eids):
graph = graph.edge_subgraph(tf.cast(eids, graph.idtype))
graph = graph.edge_subgraph(tf.cast(eids, graph.idtype), preserve_nodes=True)
gidx = graph._graph
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
......
......@@ -2,9 +2,8 @@
# pylint: disable=invalid-name
from __future__ import absolute_import, division
from ..base import DGLError, dgl_warning
from ..base import DGLError
from .. import backend as F
from .internal import to_dgl_context
def prepare_tensor(g, data, name):
"""Convert the data to ID tensor and check its ID type and context.
......@@ -129,14 +128,3 @@ def check_all_same_schema(feat_dict_list, keys, name):
' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
name, k, F.dtype(t1), F.shape(t1)[1:],
F.dtype(t2), F.shape(t2)[1:]))
def to_int32_graph_if_on_gpu(g):
"""Convert to int32 graph if the input graph is on GPU."""
# device_type 2 is an internal code for GPU
if to_dgl_context(g.device).device_type == 2 and g.idtype == F.int64:
dgl_warning('Automatically cast a GPU int64 graph to int32.\n'
' To suppress the warning, call DGLGraph.int() first\n'
' or specify the ``device`` argument when creating the graph.')
return g.int()
else:
return g
......@@ -46,6 +46,13 @@ struct Div {
}
};
struct Mod {
template <typename T>
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 % t2;
}
};
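The new Mod functor plugs into the same BinaryElewise machinery as Add/Sub/Mul/Div, and it is what backs the aten::Mod overloads and the operator% added in this patch. A small usage sketch (the array a and its contents are assumptions; the Mod and operator% signatures come from this patch):

    // a is an existing IdArray holding {5, 6, 7}.
    IdArray r1 = dgl::aten::Mod(a, 4);  // {1, 2, 3}
    IdArray r2 = a % 4;                 // same result via the new operator%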
struct GT {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
......
......@@ -287,6 +287,20 @@ IdArray NonZero(NDArray array) {
return ret;
}
std::pair<IdArray, IdArray> Sort(IdArray array) {
if (array.NumElements() == 0) {
IdArray idx = NewIdArray(0, array->ctx, 64);
return std::make_pair(array, idx);
}
std::pair<IdArray, IdArray> ret;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::Sort<XPU, IdType>(array);
});
});
return ret;
}
std::string ToDebugString(NDArray array) {
std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
......
......@@ -70,6 +70,7 @@ BINARY_ELEMENT_OP(Add, Add)
BINARY_ELEMENT_OP(Sub, Sub)
BINARY_ELEMENT_OP(Mul, Mul)
BINARY_ELEMENT_OP(Div, Div)
BINARY_ELEMENT_OP(Mod, Mod)
BINARY_ELEMENT_OP(GT, GT)
BINARY_ELEMENT_OP(LT, LT)
BINARY_ELEMENT_OP(GE, GE)
......@@ -81,6 +82,7 @@ BINARY_ELEMENT_OP_L(Add, Add)
BINARY_ELEMENT_OP_L(Sub, Sub)
BINARY_ELEMENT_OP_L(Mul, Mul)
BINARY_ELEMENT_OP_L(Div, Div)
BINARY_ELEMENT_OP_L(Mod, Mod)
BINARY_ELEMENT_OP_L(GT, GT)
BINARY_ELEMENT_OP_L(LT, LT)
BINARY_ELEMENT_OP_L(GE, GE)
......@@ -92,6 +94,7 @@ BINARY_ELEMENT_OP_R(Add, Add)
BINARY_ELEMENT_OP_R(Sub, Sub)
BINARY_ELEMENT_OP_R(Mul, Mul)
BINARY_ELEMENT_OP_R(Div, Div)
BINARY_ELEMENT_OP_R(Mod, Mod)
BINARY_ELEMENT_OP_R(GT, GT)
BINARY_ELEMENT_OP_R(LT, LT)
BINARY_ELEMENT_OP_R(GE, GE)
......@@ -117,6 +120,9 @@ NDArray operator * (const NDArray& lhs, const NDArray& rhs) {
NDArray operator / (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator + (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Add(lhs, rhs);
}
......@@ -129,6 +135,9 @@ NDArray operator * (const NDArray& lhs, int64_t rhs) {
NDArray operator / (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator + (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Add(lhs, rhs);
}
......@@ -141,6 +150,9 @@ NDArray operator * (int64_t lhs, const NDArray& rhs) {
NDArray operator / (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator - (const NDArray& array) {
return dgl::aten::Neg(array);
}
......
......@@ -46,6 +46,9 @@ DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);
......
......@@ -60,6 +60,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
......@@ -70,6 +71,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
......@@ -94,6 +96,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
......@@ -104,6 +107,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
......@@ -128,6 +132,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
......@@ -138,6 +143,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/array_sort.cc
* \brief Array sort CPU implementation
*/
#include <dgl/array.h>
#ifdef PARALLEL_ALGORITHMS
#include <parallel/algorithm>
#endif
#include <algorithm>
#include <iterator>
namespace {
template <typename V1, typename V2>
struct PairRef {
PairRef() = delete;
PairRef(const PairRef& other) = default;
PairRef(PairRef&& other) = default;
PairRef(V1 *const r, V2 *const c)
: row(r), col(c) {}
PairRef& operator=(const PairRef& other) {
*row = *other.row;
*col = *other.col;
return *this;
}
PairRef& operator=(const std::pair<V1, V2>& val) {
*row = std::get<0>(val);
*col = std::get<1>(val);
return *this;
}
operator std::pair<V1, V2>() const {
return std::make_pair(*row, *col);
}
void Swap(const PairRef& other) const {
std::swap(*row, *other.row);
std::swap(*col, *other.col);
}
V1 *row;
V2 *col;
};
using std::swap;
template <typename V1, typename V2>
void swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) {
r1.Swap(r2);
}
template <typename V1, typename V2>
struct PairIterator : public std::iterator<std::random_access_iterator_tag,
std::pair<V1, V2>,
std::ptrdiff_t,
std::pair<V1*, V2*>,
PairRef<V1, V2>> {
PairIterator() = default;
PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default;
PairIterator(V1 *r, V2 *c): row(r), col(c) {}
PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default;
bool operator==(const PairIterator& other) const {
return row == other.row;
}
bool operator!=(const PairIterator& other) const {
return row != other.row;
}
bool operator<(const PairIterator& other) const {
return row < other.row;
}
bool operator>(const PairIterator& other) const {
return row > other.row;
}
bool operator<=(const PairIterator& other) const {
return row <= other.row;
}
bool operator>=(const PairIterator& other) const {
return row >= other.row;
}
PairIterator& operator+=(const std::ptrdiff_t& movement) {
row += movement;
col += movement;
return *this;
}
PairIterator& operator-=(const std::ptrdiff_t& movement) {
row -= movement;
col -= movement;
return *this;
}
PairIterator& operator++() {
return operator+=(1);
}
PairIterator& operator--() {
return operator-=(1);
}
PairIterator operator++(int) {
PairIterator ret(*this);
operator++();
return ret;
}
PairIterator operator--(int) {
PairIterator ret(*this);
operator--();
return ret;
}
PairIterator operator+(const std::ptrdiff_t& movement) const {
PairIterator ret(*this);
ret += movement;
return ret;
}
PairIterator operator-(const std::ptrdiff_t& movement) const {
PairIterator ret(*this);
ret -= movement;
return ret;
}
std::ptrdiff_t operator-(const PairIterator& other) const {
return row - other.row;
}
PairRef<V1, V2> operator*() const {
return PairRef<V1, V2>(row, col);
}
PairRef<V1, V2> operator*() {
return PairRef<V1, V2>(row, col);
}
V1 *row;
V2 *col;
};
} // namespace
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array) {
const int64_t nitem = array->shape[0];
IdArray val = array.Clone();
IdArray idx = aten::Range(0, nitem, 64, array->ctx);
IdType* val_data = val.Ptr<IdType>();
int64_t* idx_data = idx.Ptr<int64_t>();
typedef std::pair<IdType, int64_t> Pair;
#ifdef PARALLEL_ALGORITHMS
__gnu_parallel::sort(
#else
std::sort(
#endif
PairIterator<IdType, int64_t>(val_data, idx_data),
PairIterator<IdType, int64_t>(val_data, idx_data) + nitem,
[] (const Pair& a, const Pair& b) {
return std::get<0>(a) < std::get<0>(b);
});
return std::make_pair(val, idx);
}
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray);
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten
} // namespace dgl
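The PairRef/PairIterator machinery above lets std::sort permute the value buffer and the index buffer together without materializing an array of pairs. Up to the order of equal keys (std::sort is not stable), the result matches a plain argsort plus gather, as in this standalone sketch (standard library only, not DGL code):

    #include <algorithm>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> val = {30, 10, 20};
    std::vector<int64_t> idx(val.size());
    std::iota(idx.begin(), idx.end(), 0);   // {0, 1, 2}
    std::sort(idx.begin(), idx.end(),
              [&](int64_t a, int64_t b) { return val[a] < val[b]; });
    // idx == {1, 2, 0}; gathering val through idx gives {10, 20, 30},
    // the same (values, index) pair the zip-style sort produces in place.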
......@@ -45,6 +45,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
......@@ -55,6 +56,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
......@@ -92,6 +94,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
......@@ -102,6 +105,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
......@@ -140,6 +144,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
......@@ -150,6 +155,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/array_sort.cu
* \brief Array sort GPU implementation
*/
#include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array) {
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t nitems = array->shape[0];
IdArray orig_idx = Range(0, nitems, 64, ctx);
IdArray sorted_array = NewIdArray(nitems, ctx, array->dtype.bits);
IdArray sorted_idx = NewIdArray(nitems, ctx, 64);
const IdType* keys_in = array.Ptr<IdType>();
const int64_t* values_in = orig_idx.Ptr<int64_t>();
IdType* keys_out = sorted_array.Ptr<IdType>();
int64_t* values_out = sorted_idx.Ptr<int64_t>();
// Allocate workspace
size_t workspace_size = 0;
cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
keys_in, keys_out, values_in, values_out, nitems);
void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute
cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
keys_in, keys_out, values_in, values_out, nitems);
device->FreeWorkspace(ctx, workspace);
return std::make_pair(sorted_array, sorted_idx);
}
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten
} // namespace dgl
......@@ -5,6 +5,7 @@
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
......@@ -15,7 +16,12 @@ namespace impl {
template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
CHECK(sizeof(IdType) == 4) << "CUDA COOToCSR does not support int64.";
LOG(FATAL) << "Unreachable code.";
return {};
}
template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
......@@ -32,6 +38,7 @@ CSRMatrix COOToCSR(COOMatrix coo) {
}
if (!row_sorted) {
coo = COOSort(coo);
col_sorted = coo.col_sorted;
}
const int64_t nnz = coo.row->shape[0];
......@@ -56,10 +63,86 @@ CSRMatrix COOToCSR(COOMatrix coo) {
indptr, coo.col, coo.data, col_sorted);
}
/*!
 * \brief Search for the insertion position of each needle in the hay.
 *
 * The hay is a sorted list of elements; the result is, for each needle, the
 * insertion position that keeps the hay sorted.
 *
 * It essentially performs a binary search to find the upper bound of each
 * needle element.
*
* For example:
* hay = [0, 0, 1, 2, 2]
* needle = [0, 1, 2, 3]
* then,
* out = [2, 3, 5, 5]
*/
template <typename IdType>
__global__ void _SortedSearchKernelUpperBound(
const IdType* hay, int64_t hay_size,
const IdType* needles, int64_t num_needles,
IdType* pos) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < num_needles) {
const IdType ele = needles[tx];
// binary search
IdType lo = 0, hi = hay_size;
while (lo < hi) {
IdType mid = (lo + hi) >> 1;
if (hay[mid] <= ele) {
lo = mid + 1;
} else {
hi = mid;
}
}
pos[tx] = lo;
tx += stride_x;
}
}
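On the host, the same upper-bound search can be written with std::upper_bound, which is an easy way to check the kernel against the example in the comment above (standalone sketch, not DGL code):

    #include <algorithm>
    #include <vector>

    std::vector<int64_t> hay = {0, 0, 1, 2, 2};
    std::vector<int64_t> out;
    for (int64_t n : {0, 1, 2, 3})
      out.push_back(std::upper_bound(hay.begin(), hay.end(), n) - hay.begin());
    // out == {2, 3, 5, 5}, matching the kernel's output.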
template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
const auto& ctx = coo.row->ctx;
const auto nbits = coo.row->dtype.bits;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
bool row_sorted = coo.row_sorted;
bool col_sorted = coo.col_sorted;
if (!row_sorted) {
// It is possible that the flag is simply not set (default value is false),
// so we still perform a linear scan to check the flag.
std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
}
if (!row_sorted) {
coo = COOSort(coo);
col_sorted = coo.col_sorted;
}
const int64_t nnz = coo.row->shape[0];
// TODO(minjie): Many of our current implementation assumes that CSR must have
// a data array. This is a temporary workaround. Remove this after:
// - The old immutable graph implementation is deprecated.
// - The old binary reduce kernel is deprecated.
if (!COOHasData(coo))
coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
IdArray rowids = Range(0, coo.num_rows, nbits, ctx);
const int nt = cuda::FindNumThreads(coo.num_rows);
const int nb = (coo.num_rows + nt - 1) / nt;
IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
_SortedSearchKernelUpperBound<<<nb, nt, 0, thr_entry->stream>>>(
coo.row.Ptr<int64_t>(), nnz,
rowids.Ptr<int64_t>(), coo.num_rows,
indptr.Ptr<int64_t>() + 1);
return CSRMatrix(coo.num_rows, coo.num_cols,
indptr, coo.col, coo.data, col_sorted);
}
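Because the rows are sorted at this point, the upper bound of row id i equals the number of entries with row <= i, so writing it into indptr[i + 1] (with indptr[0] left as 0 by Full) directly produces the CSR row pointer. A worked example with illustrative values:

    // sorted COO rows (num_rows = 4):         row    = {0, 0, 1, 3}
    // upper bound of each row id 0..3:        bounds = {2, 3, 3, 4}
    // written to indptr[1..4], indptr[0] = 0: indptr = {0, 2, 3, 3, 4}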
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
} // namespace dgl
......@@ -18,10 +18,14 @@ namespace impl {
template <DLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
LOG(FATAL) << "Unreachable codes";
}
template <>
void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column) {
// TODO(minjie): Current implementation is based on cusparse which only supports
// int32_t. To support int64_t, we could use the Radix sort algorithm provided
// by CUB.
CHECK(sizeof(IdType) == 4) << "CUDA COOSort does not support int64.";
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(coo->row->ctx);
// allocate cusparse handle if needed
......@@ -63,7 +67,7 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
if (sort_column) {
// First create a row indptr array and then call csrsort
int32_t* indptr = static_cast<int32_t*>(
device->AllocWorkspace(row->ctx, (coo->num_rows + 1) * sizeof(IdType)));
device->AllocWorkspace(row->ctx, (coo->num_rows + 1) * sizeof(int32_t)));
CUSPARSE_CALL(cusparseXcoo2csr(
thr_entry->cusparse_handle,
row_ptr,
......@@ -101,6 +105,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
coo->col_sorted = sort_column;
}
template <>
void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column) {
// Always sort the COO to be both row and column sorted.
IdArray pos = coo->row * coo->num_cols + coo->col;
const auto& sorted = Sort(pos);
coo->row = sorted.first / coo->num_cols;
coo->col = sorted.first % coo->num_cols;
if (aten::COOHasData(*coo))
coo->data = IndexSelect(coo->data, sorted.second);
else
coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);
coo->row_sorted = coo->col_sorted = true;
}
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
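The int64 path linearizes each (row, col) pair into a single key row * num_cols + col, sorts the keys with the new Sort routine (which also returns the original positions), and then decodes rows and columns back with integer division and the new modulo operator. A worked example with illustrative values:

    // num_cols = 10, entries (row, col) = (2, 3), (0, 7), (2, 1)
    // keys = row * 10 + col                 -> {23, 7, 21}
    // Sort(keys): values {7, 21, 23}, original positions {1, 2, 0}
    // row = values / 10                     -> {0, 2, 2}
    // col = values % 10                     -> {7, 1, 3}
    // data is gathered through {1, 2, 0}, so the result is sorted by row and,
    // within each row, by column.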
......
......@@ -5,6 +5,7 @@
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
......@@ -15,7 +16,12 @@ namespace impl {
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
CHECK(sizeof(IdType) == 4) << "CUDA CSRToCOO does not support int64.";
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
......@@ -41,12 +47,72 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
true, csr.sorted);
}
/*!
 * \brief Repeat elements.
 * \param val Values to repeat.
 * \param repeats Number of repeats for each value.
 * \param pos Position in the output buffer at which to write each value.
 * \param out Output buffer.
 * \param length Number of values.
*
* For example:
* val = [3, 0, 1]
* repeats = [1, 0, 2]
* pos = [0, 1, 1] # write to output buffer position 0, 1, 1
* then,
* out = [3, 1, 1]
*/
template <typename DType, typename IdType>
__global__ void _RepeatKernel(
const DType* val, const IdType* repeats, const IdType* pos,
DType* out, int64_t length) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
IdType off = pos[tx];
const IdType rep = repeats[tx];
const DType v = val[tx];
for (IdType i = 0; i < rep; ++i) {
out[off + i] = v;
}
tx += stride_x;
}
}
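In the int64 CSRToCOO below, this kernel is launched with val = the row ids 0..num_rows-1, repeats = the per-row nonzero counts, and pos = indptr, so each row id is written into the output COO row array once per nonzero in that row. A worked example with illustrative values:

    // indptr  = {0, 2, 3, 3, 4}   (4 rows, nnz = 4)
    // rowids  = {0, 1, 2, 3}
    // row_nnz = {2, 1, 0, 1}
    // out row = {0, 0, 1, 3}      (row 0 twice, row 1 once, row 2 skipped, row 3 once)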
template <>
COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
const int64_t nnz = csr.indices->shape[0];
const auto nbits = csr.indptr->dtype.bits;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
IdArray row_nnz = CSRGetRowNNZ(csr, rowids);
IdArray ret_row = NewIdArray(nnz, ctx, nbits);
const int nt = cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt;
_RepeatKernel<<<nb, nt, 0, thr_entry->stream>>>(
rowids.Ptr<int64_t>(), row_nnz.Ptr<int64_t>(),
csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
csr.num_rows);
return COOMatrix(csr.num_rows, csr.num_cols,
ret_row, csr.indices, csr.data,
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<XPU, IdType>(csr);
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
......@@ -85,6 +151,26 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
// The row and column field have already been reordered according
// to data, thus the data field will be deprecated.
coo.data = aten::NullArray();
coo.row_sorted = false;
coo.col_sorted = false;
return coo;
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
const auto& sorted = Sort(coo.data);
coo.row = IndexSelect(coo.row, sorted.second);
coo.col = IndexSelect(coo.col, sorted.second);
// The row and column field have already been reordered according
// to data, thus the data field will be deprecated.
coo.data = aten::NullArray();
coo.row_sorted = false;
coo.col_sorted = false;
return coo;
}
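Here the existing edge ids in coo.data serve as the sort key: Sort returns their original positions, and gathering row and col through those positions puts the entries into edge-id order, after which data can be dropped. A worked example with illustrative values:

    // coo.data = {2, 0, 1}, coo.row = {5, 6, 7}, coo.col = {8, 9, 4}
    // Sort(data): original positions {1, 2, 0}
    // row -> {6, 7, 5}, col -> {9, 4, 8}; the entry for edge id 0 now comes first.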
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/csr_sort.cc
* \brief Sort COO index
* \brief Sort CSR index
*/
#include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
......@@ -56,7 +57,11 @@ template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
CHECK(sizeof(IdType) == 4) << "CUDA CSRSort_ does not support int64.";
LOG(FATAL) << "Unreachable codes";
}
template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
// allocate cusparse handle if needed
......@@ -100,6 +105,43 @@ void CSRSort_(CSRMatrix* csr) {
device->FreeWorkspace(ctx, workspace);
}
template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
const auto& ctx = csr->indptr->ctx;
const int64_t nnz = csr->indices->shape[0];
const auto nbits = csr->indptr->dtype.bits;
if (!aten::CSRHasData(*csr))
csr->data = aten::Range(0, nnz, nbits, ctx);
IdArray new_indices = csr->indices.Clone();
IdArray new_data = csr->data.Clone();
const int64_t* offsets = csr->indptr.Ptr<int64_t>();
const int64_t* key_in = csr->indices.Ptr<int64_t>();
int64_t* key_out = new_indices.Ptr<int64_t>();
const int64_t* value_in = csr->data.Ptr<int64_t>();
int64_t* value_out = new_data.Ptr<int64_t>();
// Allocate workspace
size_t workspace_size = 0;
cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
key_in, key_out, value_in, value_out,
nnz, csr->num_rows, offsets, offsets + 1);
void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute
cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
key_in, key_out, value_in, value_out,
nnz, csr->num_rows, offsets, offsets + 1);
csr->sorted = true;
csr->indices = new_indices;
csr->data = new_data;
}
template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
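cub::DeviceSegmentedRadixSort treats each CSR row as one segment, using indptr and indptr + 1 as the begin/end offset arrays, so the column indices (with the data array carried along as values) are sorted independently within every row. A worked example with illustrative values:

    // indptr  = {0, 3, 5}                                 (2 rows, nnz = 5)
    // indices = {4, 1, 2,  9, 0}  -> sorted per row ->  {1, 2, 4,  0, 9}
    // data is permuted by the same per-row order, so (index, data) pairs stay aligned.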
......
......@@ -15,7 +15,12 @@ namespace impl {
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
CHECK(sizeof(IdType) == 4) << "CUDA CSR2CSC does not support int64.";
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
......@@ -82,6 +87,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
false);
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
}
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
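Unlike the int32 path, which transposes directly with cusparse csr2csc, the int64 specialization composes the conversions added in this patch: expand the CSR to COO, swap the row and column arrays via COOTranspose, and rebuild a CSR (sorting along the way through the int64 COOToCSR/COOSort path). A rough sketch of the data flow, under that reading:

    // CSR (m x n) -> COO -> transpose (swap row/col) -> COO (n x m) -> CSR == CSC of the original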
......