Unverified commit d6d517bb, authored by Minjie Wang, committed by GitHub
Browse files

[Kernel] CUDA CSR2COO COOSort COO2CSR (#1620)



* add cuda source

* moving codes from kernel2 branch

* operator overloading

* Better error message for unsupported device

* fix c tests

* coo sort using cusparse

* move test_rpc to distributed

* lint

* address comments and add utests
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Chao Ma <mctt90@gmail.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 61f007c4
......@@ -890,15 +890,45 @@ IdArray VecToIdArray(const std::vector<T>& vec,
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define ATEN_XPU_SWITCH(val, XPU, ...) do { \
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
/*
* Dispatch according to device:
*
* XXX(minjie): temporary macro that allows CUDA operator
*
* ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
constexpr auto XPU = kDLGPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
#else // DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif // DGL_USE_CUDA
/*
* Dispatch according to integral type (either int32 or int64):
*
......@@ -1011,17 +1041,17 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} while (0)
// Macro to dispatch according to device context and index type.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, { \
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH(coo, XPU, IdType, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
......
......@@ -17,15 +17,8 @@
#include "serializer.h"
#include "shared_mem.h"
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}
// forward declaration
inline std::ostream& operator << (std::ostream& os, DGLType t);
namespace dgl {
......@@ -210,6 +203,12 @@ class NDArray {
* \brief Get the size of the array in the number of bytes.
*/
size_t GetSize() const;
/*!
* \brief Get the number of elements in this array.
*/
int64_t NumElements() const;
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
......@@ -464,6 +463,110 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
return true;
}
/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
/*!
 * \brief Map a DGL/DLPack type code to a human-readable name.
 * \param type_code The type code value.
 * \return A static string naming the code; logs fatally on an unknown code.
 */
inline const char* TypeCode2Str(int type_code) {
  switch (type_code) {
    case kDLInt:            return "int";
    case kDLUInt:           return "uint";
    case kDLFloat:          return "float";
    case kStr:              return "str";
    case kBytes:            return "bytes";
    case kHandle:           return "handle";
    case kNull:             return "NULL";
    case kObjectHandle:     return "ObjectHandle";
    case kArrayHandle:      return "ArrayHandle";
    case kDGLType:          return "DGLType";
    case kDGLContext:       return "DGLContext";
    case kFuncHandle:       return "FunctionHandle";
    case kModuleHandle:     return "ModuleHandle";
    case kNDArrayContainer: return "NDArrayContainer";
    default:
      break;
  }
  // Unknown codes are a programming error; abort with the numeric value.
  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
  return "";
}
/*!
* \brief Convert device type code to its name
* \param device_type The device type code.
* \return The name of the device.
*/
/*!
 * \brief Map a DLPack device type code to a short device name.
 * \param device_type The device type code.
 * \return A static string naming the device; logs fatally on an unknown code.
 */
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
  switch (device_type) {
    case kDLCPU:       return "cpu";
    case kDLGPU:       return "cuda";
    case kDLCPUPinned: return "cpu_pinned";
    case kDLOpenCL:    return "opencl";
    case kDLVulkan:    return "vulkan";
    case kDLMetal:     return "metal";
    case kDLVPI:       return "vpi";
    case kDLROCM:      return "rocm";
    default:
      break;
  }
  // Unknown codes are a programming error; abort with the numeric value.
  LOG(FATAL) << "Unknown device type code="
             << static_cast<int>(device_type);
  return "";
}
/*!
* \brief convert a string to DGL type.
* \param s The string to be converted.
* \return The corresponding dgl type.
*/
/*!
 * \brief Parse a type string (e.g. "int32", "float32x4", "handle") into a
 *        DGLType.
 * \param s The string to be converted.
 * \return The corresponding dgl type.
 */
inline DGLType String2DGLType(std::string s) {
  DGLType ty;
  // Defaults: 32 bits, one lane; overridden below if the string says so.
  ty.bits = 32;
  ty.lanes = 1;
  const char* cursor;
  if (s.compare(0, 3, "int") == 0) {
    ty.code = kDLInt;
    cursor = s.c_str() + 3;
  } else if (s.compare(0, 4, "uint") == 0) {
    ty.code = kDLUInt;
    cursor = s.c_str() + 4;
  } else if (s.compare(0, 5, "float") == 0) {
    ty.code = kDLFloat;
    cursor = s.c_str() + 5;
  } else if (s.compare(0, 6, "handle") == 0) {
    ty.code = kHandle;
    ty.bits = 64;  // handle uses 64 bit by default.
    cursor = s.c_str() + 6;
  } else {
    cursor = s.c_str();
    LOG(FATAL) << "unknown type " << s;
  }
  // Parse the optional "<bits>x<lanes>" suffix, emulating
  // sscanf("%ux%u", bits, lanes).
  char* delim;
  const uint8_t parsed_bits = static_cast<uint8_t>(strtoul(cursor, &delim, 10));
  if (parsed_bits != 0) ty.bits = parsed_bits;
  if (*delim == 'x') {
    ty.lanes = static_cast<uint16_t>(strtoul(delim + 1, nullptr, 10));
  }
  return ty;
}
/*!
* \brief convert a DGL type to string.
* \param t The type to be converted.
* \return The corresponding dgl type in string.
*/
/*!
 * \brief Render a DGLType as its string form (e.g. "int32", "float32x4").
 * \param t The type to be converted.
 * \return The corresponding dgl type in string.
 */
inline std::string DGLType2String(DGLType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
  // Delegate to operator<< so both code paths share one formatting rule.
  std::ostringstream out;
  out << t;
  return out.str();
#else
  std::string out(TypeCode2Str(t.code));
  if (t.code == kHandle) return out;  // "handle" carries no bits/lanes suffix
  out.append(std::to_string(static_cast<int>(t.bits)));
  if (t.lanes != 1) {
    out.append("x").append(std::to_string(static_cast<int>(t.lanes)));
  }
  return out;
#endif
}
// macro to check type code.
#define DGL_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
} // namespace runtime
} // namespace dgl
......@@ -472,4 +575,46 @@ namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
} // namespace dmlc
///////////////// Operator overloading for DLDataType /////////////////
/*! \brief Check whether two data types are the same.*/
/*! \brief Check whether two data types are the same (code, bits, and lanes).*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
  if (ty1.code != ty2.code) return false;
  if (ty1.bits != ty2.bits) return false;
  return ty1.lanes == ty2.lanes;
}
/*! \brief Check whether two data types are different.*/
/*! \brief Check whether two data types differ in code, bits, or lanes.*/
inline bool operator != (const DLDataType& ty1, const DLDataType& ty2) {
  return ty1.code != ty2.code
      || ty1.bits != ty2.bits
      || ty1.lanes != ty2.lanes;
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
/*! \brief Stream a DGLType as "<code><bits>[x<lanes>]"; handle types print
 *         only the code name. */
inline std::ostream& operator << (std::ostream& os, DGLType t) {
  os << dgl::runtime::TypeCode2Str(t.code);
  if (t.code != kHandle) {
    os << static_cast<int>(t.bits);
    if (t.lanes != 1) {
      os << 'x' << static_cast<int>(t.lanes);
    }
  }
  return os;
}
#endif
///////////////// Operator overloading for DLContext /////////////////
/*! \brief Check whether two device contexts are the same.*/
/*! \brief Check whether two device contexts are the same (type and id).*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
  if (ctx1.device_type != ctx2.device_type) return false;
  return ctx1.device_id == ctx2.device_id;
}
/*! \brief Check whether two device contexts are different.*/
/*! \brief Check whether two device contexts differ in type or id.*/
inline bool operator != (const DLContext& ctx1, const DLContext& ctx2) {
  return ctx1.device_type != ctx2.device_type
      || ctx1.device_id != ctx2.device_id;
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
/*! \brief Stream a device context as "<device>:<id>" (e.g. "cuda:0").*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
  os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type);
  os << ":" << ctx.device_id;
  return os;
}
#endif
#endif // DGL_RUNTIME_NDARRAY_H_
......@@ -295,32 +295,6 @@ class DGLArgs {
inline DGLArgValue operator[](int i) const;
};
/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
/*!
* \brief convert a string to DGL type.
* \param s The string to be converted.
* \return The corresponding dgl type.
*/
inline DGLType String2DGLType(std::string s);
/*!
* \brief convert a DGL type to string.
* \param t The type to be converted.
* \return The corresponding dgl type in string.
*/
inline std::string DGLType2String(DGLType t);
// macro to check type code.
#define DGL_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*!
* \brief Type traits to mark if a class is dgl extension type.
*
......@@ -826,83 +800,6 @@ class DGLRetValue : public DGLPODValue_ {
};
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kObjectHandle: return "ObjectHandle";
case kArrayHandle: return "ArrayHandle";
case kDGLType: return "DGLType";
case kDGLContext: return "DGLContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DGLType t) { // NOLINT(*)
os << TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}
#endif
inline std::string DGLType2String(DGLType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
return os.str();
#else
std::string repr = "";
repr += TypeCode2Str(t.code);
if (t.code == kHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
if (t.lanes != 1) {
repr += "x" + std::to_string(static_cast<int>(t.lanes));
}
return repr;
#endif
}
inline DGLType String2DGLType(std::string s) {
DGLType t;
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));
}
return t;
}
inline DGLArgValue DGLArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
......
......@@ -900,7 +900,7 @@ class HeteroGraphIndex(ObjectBase):
HeteroGraphIndex
"""
g = self.get_relation_graph(etype)
return g.asbits(self.bits_needed(etype or 0)).copy_to(ctx)
return g.copy_to(ctx).asbits(self.bits_needed(etype or 0))
def get_csr_shuffle_order(self, etype):
"""Return the edge shuffling order when a coo graph is converted to csr format
......
......@@ -27,7 +27,7 @@ IdArray Clone(IdArray arr) {
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH(ctx.device_type, XPU, {
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
if (nbits == 32) {
ret = impl::Range<XPU, int32_t>(low, high, ctx);
} else if (nbits == 64) {
......@@ -41,7 +41,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH(ctx.device_type, XPU, {
ATEN_XPU_SWITCH(ctx.device_type, XPU, "Full", {
if (nbits == 32) {
ret = impl::Full<XPU, int32_t>(val, length, ctx);
} else if (nbits == 64) {
......@@ -54,8 +54,13 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
}
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64)
<< "Invalid ID type. Must be int32 or int64, but got int"
<< static_cast<int>(bits) << ".";
if (arr->dtype.bits == bits)
return arr;
IdArray ret;
ATEN_XPU_SWITCH(arr->ctx.device_type, XPU, {
ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
ret = impl::AsNumBits<XPU, IdType>(arr, bits);
});
......@@ -67,7 +72,7 @@ IdArray Add(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Add", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
......@@ -79,7 +84,7 @@ IdArray Sub(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
......@@ -91,7 +96,7 @@ IdArray Mul(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Mul", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
......@@ -103,7 +108,7 @@ IdArray Div(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
......@@ -113,7 +118,7 @@ IdArray Div(IdArray lhs, IdArray rhs) {
IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Add", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
......@@ -123,7 +128,7 @@ IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
......@@ -133,7 +138,7 @@ IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Mul", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
......@@ -143,7 +148,7 @@ IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray Div(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
......@@ -157,7 +162,7 @@ IdArray Add(dgl_id_t lhs, IdArray rhs) {
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
......@@ -171,7 +176,7 @@ IdArray Mul(dgl_id_t lhs, IdArray rhs) {
IdArray Div(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
......@@ -181,7 +186,7 @@ IdArray Div(dgl_id_t lhs, IdArray rhs) {
BoolArray LT(IdArray lhs, dgl_id_t rhs) {
BoolArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "LT", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::LT>(lhs, rhs);
});
......@@ -193,7 +198,7 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "HStack", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::HStack<XPU, IdType>(lhs, rhs);
});
......@@ -204,7 +209,7 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret;
// TODO(BarclayII): check if array and index match in context
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
ret = impl::IndexSelect<XPU, DType, IdType>(array, index);
......@@ -217,7 +222,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index) {
ValueType ret = 0;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ret = impl::IndexSelect<XPU, DType>(array, index);
});
......@@ -233,7 +238,7 @@ template double IndexSelect<double>(NDArray array, uint64_t index);
NDArray Scatter(NDArray array, IdArray indices) {
NDArray ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Scatter", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, {
ret = impl::Scatter<XPU, DType, IdType>(array, indices);
......@@ -245,7 +250,7 @@ NDArray Scatter(NDArray array, IdArray indices) {
NDArray Repeat(NDArray array, IdArray repeats) {
NDArray ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Repeat", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, {
ret = impl::Repeat<XPU, DType, IdType>(array, repeats);
......@@ -257,7 +262,7 @@ NDArray Repeat(NDArray array, IdArray repeats) {
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray ret;
ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, "Relabel_", {
ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
ret = impl::Relabel_<XPU, IdType>(arrays);
});
......@@ -268,7 +273,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
std::tuple<NDArray, IdArray, IdArray> ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));
});
......@@ -285,7 +290,7 @@ template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
std::pair<NDArray, IdArray> ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "ConcatSlices", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {
ret = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);
......@@ -299,7 +304,7 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRIsNonZero", {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
......@@ -307,7 +312,7 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRIsNonZero", {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
......@@ -315,7 +320,7 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
bool CSRHasDuplicate(CSRMatrix csr) {
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRHasDuplicate", {
ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
});
return ret;
......@@ -323,7 +328,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
int64_t ret = 0;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowNNZ", {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
......@@ -331,7 +336,7 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowNNZ", {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
......@@ -339,7 +344,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowColumnIndices", {
ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
});
return ret;
......@@ -347,7 +352,7 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowData", {
ret = impl::CSRGetRowData<XPU, IdType>(csr, row);
});
return ret;
......@@ -355,7 +360,7 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, row, col);
});
return ret;
......@@ -363,7 +368,7 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
});
return ret;
......@@ -372,7 +377,7 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetDataAndIndices", {
ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
});
return ret;
......@@ -380,8 +385,10 @@ std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix CSRTranspose(CSRMatrix csr) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRTranspose<XPU, IdType>(csr);
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRTranspose", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRTranspose<XPU, IdType>(csr);
});
});
return ret;
}
......@@ -389,13 +396,13 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
COOMatrix ret;
if (data_as_order) {
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
});
});
} else {
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOO<XPU, IdType>(csr);
});
......@@ -406,7 +413,7 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceRows", {
ret = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
});
return ret;
......@@ -414,7 +421,7 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceRows", {
ret = impl::CSRSliceRows<XPU, IdType>(csr, rows);
});
return ret;
......@@ -422,21 +429,21 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceMatrix", {
ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
});
return ret;
}
void CSRSort_(CSRMatrix* csr) {
ATEN_CSR_SWITCH(*csr, XPU, IdType, {
ATEN_CSR_SWITCH(*csr, XPU, IdType, "CSRSort_", {
impl::CSRSort_<XPU, IdType>(csr);
});
}
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
ret = impl::CSRRemove<XPU, IdType>(csr, entries);
});
return ret;
......@@ -445,7 +452,7 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, {
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSampling", {
if (IsNullArray(prob)) {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else {
......@@ -461,7 +468,7 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, {
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseTopk", {
ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
ret = impl::CSRRowWiseTopk<XPU, IdType, DType>(
mat, rows, k, weight, ascending);
......@@ -474,7 +481,7 @@ COOMatrix CSRRowWiseTopk(
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
bool ret = false;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
});
return ret;
......@@ -482,7 +489,7 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
});
return ret;
......@@ -490,7 +497,7 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
bool COOHasDuplicate(COOMatrix coo) {
bool ret = false;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOHasDuplicate", {
ret = impl::COOHasDuplicate<XPU, IdType>(coo);
});
return ret;
......@@ -498,7 +505,7 @@ bool COOHasDuplicate(COOMatrix coo) {
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
int64_t ret = 0;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowNNZ", {
ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
});
return ret;
......@@ -506,7 +513,7 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowNNZ", {
ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
});
return ret;
......@@ -514,7 +521,7 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(COOMatrix coo, int64_t row) {
std::pair<NDArray, NDArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
});
return ret;
......@@ -522,7 +529,7 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(COOMatrix coo, int64_t row)
NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
ret = impl::COOGetData<XPU, IdType>(coo, row, col);
});
return ret;
......@@ -531,31 +538,29 @@ NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
std::vector<NDArray> COOGetDataAndIndices(
COOMatrix coo, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetDataAndIndices", {
ret = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);
});
return ret;
}
COOMatrix COOTranspose(COOMatrix coo) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOTranspose<XPU, IdType>(coo);
});
return ret;
return COOMatrix(coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data);
}
CSRMatrix COOToCSR(COOMatrix coo) {
CSRMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOToCSR<XPU, IdType>(coo);
ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
ret = impl::COOToCSR<XPU, IdType>(coo);
});
});
return ret;
}
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
ret = impl::COOSliceRows<XPU, IdType>(coo, start, end);
});
return ret;
......@@ -563,7 +568,7 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
ret = impl::COOSliceRows<XPU, IdType>(coo, rows);
});
return ret;
......@@ -571,7 +576,7 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceMatrix", {
ret = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);
});
return ret;
......@@ -579,15 +584,17 @@ COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
COOMatrix COOSort(COOMatrix mat, bool sort_column) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
ret = impl::COOSort<XPU, IdType>(mat, sort_column);
ATEN_XPU_SWITCH_CUDA(mat.row->ctx.device_type, XPU, "COOSort", {
ATEN_ID_TYPE_SWITCH(mat.row->dtype, IdType, {
ret = impl::COOSort<XPU, IdType>(mat, sort_column);
});
});
return ret;
}
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
ret = impl::COORemove<XPU, IdType>(coo, entries);
});
return ret;
......@@ -596,7 +603,7 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
if (IsNullArray(prob)) {
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else {
......@@ -612,7 +619,7 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseTopk", {
ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
ret = impl::COORowWiseTopk<XPU, IdType, DType>(
mat, rows, k, weight, ascending);
......@@ -623,7 +630,7 @@ COOMatrix COORowWiseTopk(
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
std::pair<COOMatrix, IdArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ATEN_COO_SWITCH(coo, XPU, IdType, "COOCoalesce", {
ret = impl::COOCoalesce<XPU, IdType>(coo);
});
return ret;
......
/*!
* Copyright (c) 2019 by Contributors
* \file array/cuda/array_op_impl.cu
* \brief Array operator GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
/*!
 * \brief Find the largest power-of-two thread count that does not exceed
 *        both \a dim and \a max_nthrs.
 * \param dim The amount of parallel work (e.g. array length).
 * \param max_nthrs Upper bound on the thread count; expected to be a power
 *        of two (e.g. 1024, the CUDA per-block limit).
 * \return A power-of-two value in [1, max_nthrs].
 *
 * The result is clamped to at least 1: returning 0 for dim == 0 would make
 * caller grid-size computations such as (length + nt - 1) / nt divide by
 * zero and request an invalid zero-thread launch.
 */
int FindNumThreads(int dim, int max_nthrs) {
  int ret = max_nthrs;
  // Halve until the thread count no longer exceeds the work size.
  while (ret > dim) {
    ret = ret >> 1;
  }
  // Never return 0 threads (happens when dim <= 0).
  return ret > 0 ? ret : 1;
}
///////////////////////////// Range /////////////////////////////
template <typename IdType>
__global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[tx] = low + tx;
tx += stride_x;
}
}
/*!
 * \brief Create a 1D id array containing [low, low + 1, ..., high - 1] on
 *        the device given by \a ctx (GPU implementation).
 * \param low Inclusive lower bound.
 * \param high Exclusive upper bound; must satisfy high >= low.
 * \param ctx Device context on which the array is allocated.
 * \return A new IdArray of length (high - low) with bit width sizeof(IdType)*8.
 */
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
  CHECK(high >= low) << "high must be bigger than low";
  const IdType length = high - low;
  IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
  // Empty range: return the empty array without launching a kernel
  // (a zero-size launch would be invalid).
  if (length == 0)
    return ret;
  IdType* ret_data = static_cast<IdType*>(ret->data);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // Threads per block (power of two, capped at the CUDA limit of 1024),
  // then enough blocks to cover `length` elements: ceil(length / nt).
  int nt = FindNumThreads(length, 1024);
  int nb = (length + nt - 1) / nt;
  // Launch on the thread-local stream so ordering with other DGL CUDA work
  // on this thread is preserved.
  _RangeKernel<IdType><<<nb, nt, 0, thr_entry->stream>>>(ret_data, low, length);
  return ret;
}
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// AsNumBits /////////////////////////////
/*!
 * \brief Element-wise copy from \a in to \a out with an implicit value
 *        conversion from InType to OutType.
 *
 * Uses a grid-stride loop so the full range is covered for any launch shape.
 */
template <typename InType, typename OutType>
__global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
  const int stride = gridDim.x * blockDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < length;
       idx += stride) {
    out[idx] = in[idx];
  }
}
/*!
 * \brief Convert an id array to the given integer bit width (GPU).
 * \param arr The input id array (any shape).
 * \param bits Target bit width; expected to be 32 or 64 (anything else is
 *        treated as 64 by the cast below).
 * \return A new array of the same shape holding the converted values.
 */
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
  IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
  const int64_t length = ret.NumElements();
  // Empty input: FindNumThreads(0, ...) followed by the grid-size division
  // below would be ill-defined, and a zero-size launch is invalid, so
  // return the (empty) result directly.
  if (length == 0)
    return ret;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // Threads per block and ceil(length / nt) blocks.
  int nt = FindNumThreads(length, 1024);
  int nb = (length + nt - 1) / nt;
  if (bits == 32) {
    _CastKernel<IdType, int32_t><<<nb, nt, 0, thr_entry->stream>>>(
        static_cast<IdType*>(arr->data), static_cast<int32_t*>(ret->data), length);
  } else {
    _CastKernel<IdType, int64_t><<<nb, nt, 0, thr_entry->stream>>>(
        static_cast<IdType*>(arr->data), static_cast<int64_t*>(ret->data), length);
  }
  return ret;
}
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/coo2csr.cc
* \brief COO2CSR
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
/*!
 * \brief Convert a COO matrix to CSR on the GPU using cuSPARSE.
 *
 * If the input is not row-sorted, the row/col/data arrays are sorted by row
 * first via cusparseXcoosortByRow. That sort works in place, so copies of
 * the input arrays are made to keep the caller's COOMatrix untouched.
 *
 * \param coo Input COO matrix; only int32 indices are supported.
 * \return The equivalent CSRMatrix (column indices within a row are not
 *         guaranteed sorted, hence the final `false` flag).
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  CHECK(sizeof(IdType) == 4) << "CUDA COOToCSR does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo.row->ctx);
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  NDArray row = coo.row, col = coo.col, data = coo.data;
  int32_t* row_ptr = static_cast<int32_t*>(row->data);
  int32_t* col_ptr = static_cast<int32_t*>(col->data);
  int32_t* data_ptr = aten::IsNullArray(data) ? nullptr : static_cast<int32_t*>(data->data);
  if (!coo.row_sorted) {
    // cusparseXcoosortByRow permutes its arguments in place, so operate on
    // copies of row/col/data to avoid mutating the caller's arrays.
    row = row.CopyTo(row->ctx);
    col = col.CopyTo(col->ctx);
    row_ptr = static_cast<int32_t*>(row->data);
    col_ptr = static_cast<int32_t*>(col->data);
    if (aten::IsNullArray(data)) {
      // No data array given: create the identity permutation [0, nnz) so
      // the sort records where each entry moved.
      data = aten::Range(0, row->shape[0], row->dtype.bits, row->ctx);
    } else {
      // BUGFIX: previously a caller-provided data array was sorted in place,
      // silently permuting coo.data. Copy it first.
      data = data.CopyTo(data->ctx);
    }
    data_ptr = static_cast<int32_t*>(data->data);
    // Query workspace size, sort (row, col) pairs by row and apply the same
    // permutation to data.
    size_t workspace_size = 0;
    CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
        thr_entry->cusparse_handle,
        coo.num_rows, coo.num_cols,
        row->shape[0],
        row_ptr,
        col_ptr,
        &workspace_size));
    void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
    CUSPARSE_CALL(cusparseXcoosortByRow(
        thr_entry->cusparse_handle,
        coo.num_rows, coo.num_cols,
        row->shape[0],
        row_ptr,
        col_ptr,
        data_ptr,
        workspace));
    device->FreeWorkspace(row->ctx, workspace);
  }
  // Compress the (now sorted) row indices into an indptr array of length
  // num_rows + 1.
  NDArray indptr = aten::NewIdArray(coo.num_rows + 1, row->ctx, row->dtype.bits);
  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  CUSPARSE_CALL(cusparseXcoo2csr(
      thr_entry->cusparse_handle,
      row_ptr,
      row->shape[0],
      coo.num_rows,
      indptr_ptr,
      CUSPARSE_INDEX_BASE_ZERO));
  return CSRMatrix(coo.num_rows, coo.num_cols,
                   indptr, col, data, false);
}
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/coo_sort.cc
* \brief Sort COO index
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
/*!
 * \brief Sort a COO matrix by row (and optionally by column within each row)
 *        on the GPU using cuSPARSE.
 *
 * The input is never mutated: row/col/data are copied up front because the
 * cuSPARSE sort routines work in place. If the input has no data array, an
 * identity permutation [0, nnz) is created so the result records where each
 * entry moved.
 *
 * \param coo Input COO matrix; only int32 indices are supported.
 * \param sort_column If true, also sort column indices within each row.
 * \return A new COOMatrix marked row_sorted (and col_sorted == sort_column).
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix COOSort(COOMatrix coo, bool sort_column) {
  CHECK(sizeof(IdType) == 4) << "CUDA COOSort does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo.row->ctx);
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  // Copy the input arrays; the sorts below permute their arguments in place.
  NDArray row = coo.row.CopyTo(coo.row->ctx);
  NDArray col = coo.col.CopyTo(coo.col->ctx);
  NDArray data;
  if (aten::IsNullArray(coo.data)) {
    // create the index array
    data = aten::Range(0, row->shape[0], row->dtype.bits, row->ctx);
  } else {
    data = coo.data.CopyTo(coo.data->ctx);
  }
  int32_t* row_ptr = static_cast<int32_t*>(row->data);
  int32_t* col_ptr = static_cast<int32_t*>(col->data);
  int32_t* data_ptr = static_cast<int32_t*>(data->data);
  // sort row: permutes (row, col) pairs by row and applies the same
  // permutation to data.
  size_t workspace_size = 0;
  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      row_ptr,
      col_ptr,
      &workspace_size));
  void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
  CUSPARSE_CALL(cusparseXcoosortByRow(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      row_ptr,
      col_ptr,
      data_ptr,
      workspace));
  device->FreeWorkspace(row->ctx, workspace);
  if (sort_column) {
    // First create a row indptr array and then call csrsort, which sorts
    // column indices within each row and permutes data accordingly.
    int32_t* indptr = static_cast<int32_t*>(
        device->AllocWorkspace(row->ctx, (coo.num_rows + 1) * sizeof(IdType)));
    CUSPARSE_CALL(cusparseXcoo2csr(
        thr_entry->cusparse_handle,
        row_ptr,
        row->shape[0],
        coo.num_rows,
        indptr,
        CUSPARSE_INDEX_BASE_ZERO));
    CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
        thr_entry->cusparse_handle,
        coo.num_rows,
        coo.num_cols,
        row->shape[0],
        indptr,
        col_ptr,
        &workspace_size));
    // NOTE: this `workspace` intentionally shadows the (already freed) outer
    // one; it is sized for the csrsort call.
    void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
    cusparseMatDescr_t descr;
    CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
    CUSPARSE_CALL(cusparseXcsrsort(
        thr_entry->cusparse_handle,
        coo.num_rows,
        coo.num_cols,
        row->shape[0],
        descr,
        indptr,
        col_ptr,
        data_ptr,
        workspace));
    CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
    device->FreeWorkspace(row->ctx, workspace);
    device->FreeWorkspace(row->ctx, indptr);
  }
  return COOMatrix(coo.num_rows, coo.num_cols,
                   row, col, data, true, sort_column);
}
template COOMatrix COOSort<kDLGPU, int32_t>(COOMatrix coo, bool sort_column);
template COOMatrix COOSort<kDLGPU, int64_t>(COOMatrix coo, bool sort_column);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/csr2coo.cc
* \brief CSR2COO
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
/*!
 * \brief Convert a CSR matrix to COO on the GPU using cusparseXcsr2coo.
 *
 * Only the row array is newly materialized (by expanding indptr); the
 * indices and data arrays of the input are shared with the result.
 *
 * \param csr Input CSR matrix; only int32 indices are supported.
 * \return The equivalent COOMatrix; rows are sorted by construction, and
 *         column sortedness is inherited from csr.sorted.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
  CHECK(sizeof(IdType) == 4) << "CUDA CSRToCOO does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
  const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  // Expand indptr (length num_rows + 1) into an explicit per-entry row array
  // of length nnz.
  NDArray row = aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits);
  int32_t* row_ptr = static_cast<int32_t*>(row->data);
  CUSPARSE_CALL(cusparseXcsr2coo(
      thr_entry->cusparse_handle,
      indptr_ptr,
      indices->shape[0],
      csr.num_rows,
      row_ptr,
      CUSPARSE_INDEX_BASE_ZERO));
  return COOMatrix(csr.num_rows, csr.num_cols,
                   row, indices, data,
                   true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
/*!
 * \brief Convert CSR to COO, reordering the entries so they follow the order
 *        given by the data (edge id) array.
 *
 * Implementation trick: cusparseXcoosortByRow is invoked with `data` in the
 * "row key" slot and `row` in the "column" slot, with `col` passed as the
 * permutation array. This sorts all three arrays in place by ascending data
 * value. NOTE(review): this relies on cuSPARSE permuting the P array
 * consistently with the key arrays — confirm against the cuSPARSE docs when
 * upgrading the toolkit.
 *
 * \param csr Input CSR matrix; only int32 indices are supported.
 * \return COOMatrix whose i-th entry corresponds to data value i.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  COOMatrix coo = CSRToCOO<XPU, IdType>(csr);
  // Null data means the entries are already in id order; nothing to permute.
  if (aten::IsNullArray(coo.data))
    return coo;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo.row->ctx);
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  NDArray row = coo.row, col = coo.col, data = coo.data;
  int32_t* row_ptr = static_cast<int32_t*>(row->data);
  int32_t* col_ptr = static_cast<int32_t*>(col->data);
  int32_t* data_ptr = static_cast<int32_t*>(data->data);
  // Workspace sized for sorting with `data` as the key (see doc above).
  size_t workspace_size = 0;
  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      data_ptr,
      row_ptr,
      &workspace_size));
  void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
  // In-place sort: keys = (data, row), permutation array = col.
  CUSPARSE_CALL(cusparseXcoosortByRow(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      data_ptr,
      row_ptr,
      col_ptr,
      workspace));
  device->FreeWorkspace(row->ctx, workspace);
  return coo;
}
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/csr_transpose.cc
* \brief CSR transpose (convert to CSC)
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
/*!
 * \brief Transpose a CSR matrix (i.e., convert it to CSC) on the GPU.
 *
 * Uses cusparseCsr2cscEx2 on CUDA >= 10.1 and the legacy cusparseScsr2csc
 * otherwise. The integer data (edge id) array is fed to cuSPARSE as 32-bit
 * values via CUDA_R_32F / the S-variant; this relies only on both types
 * being 32 bits wide since ACTION_NUMERIC merely moves the values.
 *
 * \param csr Input CSR matrix; only int32 indices are supported.
 * \return The transposed matrix as CSR (num_rows/num_cols swapped).
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
  CHECK(sizeof(IdType) == 4) << "CUDA CSR2CSC does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
  const int64_t nnz = indices->shape[0];
  const auto& ctx = indptr->ctx;
  const auto bits = indptr->dtype.bits;
  // Materialize the identity permutation when no data array is given so the
  // transpose records where each entry moved.
  if (aten::IsNullArray(data))
    data = aten::Range(0, nnz, bits, ctx);
  const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
  const void* data_ptr = data->data;
  // Output buffers for the transposed (CSC) layout.
  NDArray t_indptr = aten::NewIdArray(csr.num_cols + 1, ctx, bits);
  NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
  NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
  int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data);
  int32_t* t_indices_ptr = static_cast<int32_t*>(t_indices->data);
  void* t_data_ptr = t_data->data;
#if __CUDA_API_VERSION >= 10010
  auto device = runtime::DeviceAPI::Get(csr.indptr->ctx);
  // workspace
  size_t workspace_size;
  CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize(
      thr_entry->cusparse_handle,
      csr.num_rows, csr.num_cols, nnz,
      data_ptr, indptr_ptr, indices_ptr,
      t_data_ptr, t_indptr_ptr, t_indices_ptr,
      CUDA_R_32F,
      CUSPARSE_ACTION_NUMERIC,
      CUSPARSE_INDEX_BASE_ZERO,
      CUSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(cusparseCsr2cscEx2(
      thr_entry->cusparse_handle,
      csr.num_rows, csr.num_cols, nnz,
      data_ptr, indptr_ptr, indices_ptr,
      t_data_ptr, t_indptr_ptr, t_indices_ptr,
      CUDA_R_32F,  // BUGFIX: the valType argument was missing, which does
                   // not match the cusparseCsr2cscEx2 signature (it is
                   // passed to the _bufferSize call above).
      CUSPARSE_ACTION_NUMERIC,
      CUSPARSE_INDEX_BASE_ZERO,
      CUSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference
      workspace));
  device->FreeWorkspace(ctx, workspace);
#else
  CUSPARSE_CALL(cusparseScsr2csc(
      thr_entry->cusparse_handle,
      csr.num_rows, csr.num_cols, nnz,
      static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr,
      static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr,
      CUSPARSE_ACTION_NUMERIC,
      CUSPARSE_INDEX_BASE_ZERO));
#endif
  return CSRMatrix(csr.num_cols, csr.num_rows,
                   t_indptr, t_indices, t_data,
                   false);
}
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
} // namespace dgl
......@@ -15,24 +15,6 @@
#include <vector>
#include <string>
using dgl::runtime::operator<<;
/*! \brief Output the string representation of device context.*/
inline std::ostream& operator<<(std::ostream& os, const DLContext& ctx) {
std::string device_name;
switch (ctx.device_type) {
case kDLCPU:
device_name = "CPU";
break;
case kDLGPU:
device_name = "GPU";
break;
default:
device_name = "Unknown device";
}
return os << device_name << ":" << ctx.device_id;
}
namespace dgl {
// Communicator handler type
......
......@@ -275,7 +275,7 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(
const int64_t bits = NumBits();
if (bits == 32) {
return FlattenImpl<int32_t>(etypes);
} else if (bits == 64) {
} else {
return FlattenImpl<int64_t>(etypes);
}
}
......
......@@ -51,7 +51,7 @@ std::pair<IdArray, TypeArray> RandomWalk(
TypeArray vtypes;
IdArray vids;
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, {
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalk", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
vids = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
......@@ -72,7 +72,7 @@ std::pair<IdArray, TypeArray> RandomWalkWithRestart(
TypeArray vtypes;
IdArray vids;
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, {
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithRestart", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
vids = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob);
......@@ -93,7 +93,7 @@ std::pair<IdArray, TypeArray> RandomWalkWithStepwiseRestart(
TypeArray vtypes;
IdArray vids;
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, {
ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithStepwiseRestart", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
vids = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
......
......@@ -123,6 +123,14 @@ size_t NDArray::GetSize() const {
return GetDataSize(data_->dl_tensor);
}
int64_t NDArray::NumElements() const {
int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i];
}
return size;
}
bool NDArray::IsContiguous() const {
CHECK(data_ != nullptr);
if (data_->dl_tensor.strides == nullptr)
......
......@@ -3,6 +3,12 @@
#include <dgl/runtime/ndarray.h>
static constexpr DLContext CTX = DLContext{kDLCPU, 0};
static constexpr DLContext CPU = DLContext{kDLCPU, 0};
#ifdef DGL_USE_CUDA
static constexpr DLContext GPU = DLContext{kDLGPU, 0};
#endif
template <typename T>
inline T* Ptr(dgl::runtime::NDArray nd) {
return static_cast<T*>(nd->data);
......@@ -29,6 +35,9 @@ inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {
return false;
num *= a1->shape[i];
}
if (a1->ctx != a2->ctx) return false;
a1 = a1.CopyTo(CPU);
a2 = a2.CopyTo(CPU);
for (int64_t i = 0; i < num; ++i)
if (static_cast<T*>(a1->data)[i] != static_cast<T*>(a2->data)[i])
return false;
......@@ -46,6 +55,4 @@ inline bool IsInArray(dgl::runtime::NDArray a, T x) {
return false;
}
static constexpr DLContext CTX = DLContext{kDLCPU, 0};
#endif // TEST_COMMON_H_
......@@ -25,14 +25,22 @@ TEST(ArrayTest, TestCreate) {
ASSERT_EQ(Len(a), 0);
};
TEST(ArrayTest, TestRange) {
IdArray a = aten::Range(10, 10, 64, CTX);
void _TestRange(DLContext ctx) {
IdArray a = aten::Range(10, 10, 64, ctx);
ASSERT_EQ(Len(a), 0);
a = aten::Range(10, 20, 32, CTX);
a = aten::Range(10, 20, 32, ctx);
ASSERT_EQ(Len(a), 10);
ASSERT_EQ(a->dtype.bits, 32);
a = a.CopyTo(CPU);
for (int i = 0; i < 10; ++i)
ASSERT_EQ(Ptr<int32_t>(a)[i], i + 10);
}
TEST(ArrayTest, TestRange) {
_TestRange(CPU);
#ifdef DGL_USE_CUDA
_TestRange(GPU);
#endif
};
TEST(ArrayTest, TestFull) {
......@@ -61,12 +69,20 @@ TEST(ArrayTest, TestClone) {
}
};
TEST(ArrayTest, TestAsNumBits) {
IdArray a = aten::Range(0, 10, 32, CTX);
void _TestNumBits(DLContext ctx) {
IdArray a = aten::Range(0, 10, 32, ctx);
a = aten::AsNumBits(a, 64);
ASSERT_EQ(a->dtype.bits, 64);
a = a.CopyTo(CPU);
for (int i = 0; i < 10; ++i)
ASSERT_EQ(PI64(a)[i], i);
}
TEST(ArrayTest, TestAsNumBits) {
_TestNumBits(CPU);
#ifdef DGL_USE_CUDA
_TestNumBits(GPU);
#endif
};
template <typename IDX>
......
......@@ -8,7 +8,7 @@ using namespace dgl::runtime;
namespace {
template <typename IDX>
aten::CSRMatrix CSR1() {
aten::CSRMatrix CSR1(DLContext ctx = CTX) {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
......@@ -16,14 +16,14 @@ aten::CSRMatrix CSR1() {
// data: [0, 2, 3, 1, 4]
return aten::CSRMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, ctx),
false);
}
template <typename IDX>
aten::CSRMatrix CSR2() {
aten::CSRMatrix CSR2(DLContext ctx = CTX) {
// has duplicate entries
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
......@@ -32,14 +32,14 @@ aten::CSRMatrix CSR2() {
// data: [0, 2, 5, 3, 1, 4]
return aten::CSRMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
false);
}
template <typename IDX>
aten::COOMatrix COO1() {
aten::COOMatrix COO1(DLContext ctx = CTX) {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
......@@ -49,12 +49,12 @@ aten::COOMatrix COO1() {
// col : [1, 2, 2, 0, 3]
return aten::COOMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, CTX));
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, ctx));
}
template <typename IDX>
aten::COOMatrix COO2() {
aten::COOMatrix COO2(DLContext ctx = CTX) {
// has duplicate entries
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
......@@ -65,40 +65,40 @@ aten::COOMatrix COO2() {
// col : [1, 2, 2, 0, 3, 2]
return aten::COOMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, CTX));
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, ctx));
}
template <typename IDX>
aten::CSRMatrix SR_CSR3() {
aten::CSRMatrix SR_CSR3(DLContext ctx) {
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
return aten::CSRMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
false);
}
template <typename IDX>
aten::CSRMatrix SRC_CSR3() {
aten::CSRMatrix SRC_CSR3(DLContext ctx) {
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
return aten::CSRMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
false);
}
template <typename IDX>
aten::COOMatrix COO3() {
aten::COOMatrix COO3(DLContext ctx) {
// has duplicate entries
// [[0, 1, 2, 0, 0],
// [1, 0, 0, 0, 0],
......@@ -108,11 +108,11 @@ aten::COOMatrix COO3() {
// col : [2, 2, 1, 0, 3, 2]
return aten::COOMatrix(
4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, CTX),
aten::VecToIdArray(std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX)*8, CTX));
aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
aten::VecToIdArray(std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX)*8, ctx));
}
}
} // namespace
template <typename IDX>
void _TestCSRIsNonZero() {
......@@ -227,8 +227,8 @@ TEST(SpmatTest, TestCSRGetDataAndIndices) {
}
template <typename IDX>
void _TestCSRTranspose() {
auto csr = CSR2<IDX>();
void _TestCSRTranspose(DLContext ctx) {
auto csr = CSR2<IDX>(ctx);
auto csr_t = aten::CSRTranspose(csr);
// [[0, 1, 0, 0],
// [1, 0, 0, 0],
......@@ -238,29 +238,32 @@ void _TestCSRTranspose() {
// data: [3, 0, 2, 5, 1, 4]
ASSERT_EQ(csr_t.num_rows, 5);
ASSERT_EQ(csr_t.num_cols, 4);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX)*8, CTX);
auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, CTX);
auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX)*8, ctx);
auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(csr_t.data, td));
}
TEST(SpmatTest, TestCSRTranspose) {
_TestCSRTranspose<int32_t>();
_TestCSRTranspose<int64_t>();
_TestCSRTranspose<int32_t>(CPU);
_TestCSRTranspose<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCSRTranspose<int32_t>(GPU);
#endif
}
template <typename IDX>
void _TestCSRToCOO() {
auto csr = CSR2<IDX>();
void _TestCSRToCOO(DLContext ctx) {
auto csr = CSR2<IDX>(ctx);
{
auto coo = CSRToCOO(csr, false);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, CTX);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx);
auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tc));
ASSERT_TRUE(ArrayEQ<IDX>(coo.data, td));
......@@ -269,15 +272,18 @@ void _TestCSRToCOO() {
auto coo = CSRToCOO(csr, true);
ASSERT_EQ(coo.num_rows, 4);
ASSERT_EQ(coo.num_cols, 5);
auto tcoo = COO2<IDX>();
auto tcoo = COO2<IDX>(ctx);
ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tcoo.row));
ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tcoo.col));
}
}
TEST(SpmatTest, TestCSRToCOO) {
_TestCSRToCOO<int32_t>();
_TestCSRToCOO<int64_t>();
_TestCSRToCOO<int32_t>(CPU);
_TestCSRToCOO<int64_t>(CPU);
#if DGL_USE_CUDA
_TestCSRToCOO<int32_t>(GPU);
#endif
}
template <typename IDX>
......@@ -355,48 +361,40 @@ TEST(SpmatTest, TestCSRHasDuplicate) {
}
template <typename IDX>
void _TestCOOToCSR() {
auto coo = COO1<IDX>();
auto csr = CSR1<IDX>();
void _TestCOOToCSR(DLContext ctx) {
auto coo = COO1<IDX>(ctx);
auto csr = CSR1<IDX>(ctx);
auto tcsr = aten::COOToCSR(coo);
ASSERT_EQ(coo.num_rows, csr.num_rows);
ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(csr.data, tcsr.data));
coo = COO2<IDX>();
csr = CSR2<IDX>();
coo = COO2<IDX>(ctx);
csr = CSR2<IDX>(ctx);
tcsr = aten::COOToCSR(coo);
ASSERT_EQ(coo.num_rows, csr.num_rows);
ASSERT_EQ(coo.num_cols, csr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(csr.data, tcsr.data));
coo = COO1<IDX>();
coo = COO1<IDX>(ctx);
auto rs_coo = aten::COOSort(coo, false);
auto rs_csr = CSR1<IDX>();
auto rs_csr = CSR1<IDX>(ctx);
auto rs_tcsr = aten::COOToCSR(rs_coo);
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
coo = COO3<IDX>();
coo = COO3<IDX>(ctx);
rs_coo = aten::COOSort(coo, false);
rs_csr = SR_CSR3<IDX>();
rs_csr = SR_CSR3<IDX>(ctx);
rs_tcsr = aten::COOToCSR(rs_coo);
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
coo = COO1<IDX>();
coo = COO1<IDX>(ctx);
auto src_coo = aten::COOSort(coo, true);
auto src_csr = CSR1<IDX>();
auto src_csr = CSR1<IDX>(ctx);
auto src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
......@@ -404,9 +402,9 @@ void _TestCOOToCSR() {
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.indices, src_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(src_csr.data, src_tcsr.data));
coo = COO3<IDX>();
coo = COO3<IDX>(ctx);
src_coo = aten::COOSort(coo, true);
src_csr = SRC_CSR3<IDX>();
src_csr = SRC_CSR3<IDX>(ctx);
src_tcsr = aten::COOToCSR(src_coo);
ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
......@@ -416,8 +414,11 @@ void _TestCOOToCSR() {
}
TEST(SpmatTest, TestCOOToCSR) {
_TestCOOToCSR<int32_t>();
_TestCOOToCSR<int64_t>();
_TestCOOToCSR<int32_t>(CPU);
_TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCOOToCSR<int32_t>(GPU);
#endif
}
template <typename IDX>
......@@ -434,8 +435,8 @@ TEST(SpmatTest, TestCOOHasDuplicate) {
}
template <typename IDX>
void _TestCOOSort() {
auto coo = COO3<IDX>();
void _TestCOOSort(DLContext ctx) {
auto coo = COO3<IDX>(ctx);
auto sr_coo = COOSort(coo, false);
ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
......@@ -460,25 +461,22 @@ void _TestCOOSort() {
// row : [0, 0, 0, 1, 2, 2]
// col : [1, 2, 2, 0, 2, 3]
auto sort_row = aten::VecToIdArray(
std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, CTX);
auto unsort_col = aten::VecToIdArray(
std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
auto unsort_col_data = aten::VecToIdArray(
std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
auto sort_col = aten::VecToIdArray(
std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx);
auto sort_col_data = aten::VecToIdArray(
std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.row, sort_row));
ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.col, unsort_col));
ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.data, unsort_col_data));
ASSERT_TRUE(ArrayEQ<IDX>(src_coo.row, sort_row));
ASSERT_TRUE(ArrayEQ<IDX>(src_coo.col, sort_col));
ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
}
TEST(SpmatTest, TestCOOSort) {
_TestCOOSort<int32_t>();
_TestCOOSort<int64_t>();
_TestCOOSort<int32_t>(CPU);
_TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestCOOSort<int32_t>(GPU);
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment