Unverified commit cded5b80, authored by Xin Yao and committed by GitHub

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
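
The bullets above amount to one design move: after this commit the core library traffics only in DGL-prefixed types (DGLContext, DGLDataType, DGLDeviceType, kDGLCPU, kDGLCUDA, kDGLInt), and the DLPack v0.7 structs (DLDevice, DLDataType, kDLCUDA) appear only at the FFI boundary, in the converter renamed to dlpack_convert. The C++ sketch below illustrates what such a boundary can look like; the stand-in struct definitions and the helper names ToDLPackDevice/ToDLPackDataType are assumptions for illustration, not DGL's actual code.

#include <cstdint>
#include <dlpack/dlpack.h>  // DLPack v0.7: DLDevice, DLDataType, kDLCPU, kDLCUDA

// Stand-ins for the DGL-owned copies of the former DLPack structs; the real
// definitions live in DGL's headers after this change. Enum values assumed.
enum DGLDeviceType { kDGLCPU = 1, kDGLCUDA = 2 };
struct DGLContext  { DGLDeviceType device_type; int device_id; };
struct DGLDataType { uint8_t code; uint8_t bits; uint16_t lanes; };

// Conversion happens only at the DLPack boundary, so core code never
// includes dlpack.h directly.
DLDevice ToDLPackDevice(const DGLContext& ctx) {
  DLDevice dev;
  dev.device_type = (ctx.device_type == kDGLCUDA) ? kDLCUDA : kDLCPU;
  dev.device_id = ctx.device_id;
  return dev;
}

DLDataType ToDLPackDataType(const DGLDataType& t) {
  return DLDataType{t.code, t.bits, t.lanes};
}

Owning the struct definitions lets DGL bump the vendored DLPack version without touching every kernel file, which is exactly why the renames below are so mechanical.
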
@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals(
     total_len += traces[i].size();
   }
   IdArray ret = IdArray::Empty({total_len},
-      DLDataType{kDLInt, sizeof(DType) * 8, 1},
-      DLContext{kDLCPU, 0});
+      DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
+      DGLContext{kDGLCPU, 0});
   DType* ret_data = static_cast<DType*>(ret->data);
   for (int64_t i = 0; i < max_len; ++i) {
     for (size_t j = 0; j < traces.size(); ++j) {
@@ -79,7 +79,7 @@ IdArray ComputeMergedSections(
     const int64_t tracelen = traces[i].size();
     max_len = std::max(max_len, tracelen);
   }
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
   int64_t* ret_data = static_cast<int64_t*>(ret->data);
   for (int64_t i = 0; i < max_len; ++i) {
     int64_t sec_len = 0;
@@ -96,7 +96,7 @@ IdArray ComputeMergedSections(
 } // namespace
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
   std::vector<IdType> ids;
   std::vector<int64_t> sections;
@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
   return front;
 }
-template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
-template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
+template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
+template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
   std::vector<IdType> ids;
   std::vector<int64_t> sections;
@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
   return front;
 }
-template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
-template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
+template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
+template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
   std::vector<IdType> ids;
   std::vector<int64_t> sections;
@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
   return front;
 }
-template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
-template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
+template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&);
+template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
   const int64_t len = source->shape[0];
   const IdType* src_data = static_cast<IdType*>(source->data);
@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
   return front;
 }
-template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
-template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
+template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
+template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
     IdArray source,
     const bool has_reverse_edge,
@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
   return front;
 }
-template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
+template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&,
     IdArray,
     const bool,
    const bool,
     const bool);
-template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
+template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
     IdArray,
     const bool,
     const bool,
...
@@ -13,7 +13,7 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 IdArray CumSum(IdArray array, bool prepend_zero) {
   const int64_t len = array.NumElements();
   if (len == 0)
@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
   return ret;
 }
-template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
-template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
+template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
+template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
 } // namespace impl
 } // namespace aten
...
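
Nearly every hunk in this commit is the same mechanical rename inside an explicit template instantiation, like the two CumSum lines above. The pattern those instantiation lines serve: the device type is a non-type template parameter, each backend translation unit (.cc for CPU, .cu for CUDA) explicitly instantiates the shared signature for its own enum value, and a runtime switch on the array's device type picks the matching symbol. A self-contained C++ sketch of that pattern follows; the names (SumImpl, Sum) and enum values are illustrative assumptions, not DGL's actual dispatch macros.

#include <cstdint>
#include <stdexcept>

enum DGLDeviceType { kDGLCPU = 1, kDGLCUDA = 2 };  // assumed values

// Shared signature; in DGL the CPU and CUDA definitions live in separate
// translation units, each ending with explicit instantiations like the
// `template IdArray CumSum<kDGLCUDA, ...>` lines above.
template <DGLDeviceType XPU, typename IdType>
IdType SumImpl(const IdType* data, int64_t len) {
  IdType s = 0;
  for (int64_t i = 0; i < len; ++i) s += data[i];
  return s;
}
template int64_t SumImpl<kDGLCPU, int64_t>(const int64_t*, int64_t);
template int64_t SumImpl<kDGLCUDA, int64_t>(const int64_t*, int64_t);

// Runtime dispatch: map the device type to the matching instantiation.
template <typename IdType>
IdType Sum(DGLDeviceType dev, const IdType* data, int64_t len) {
  switch (dev) {
    case kDGLCPU:  return SumImpl<kDGLCPU, IdType>(data, len);
    case kDGLCUDA: return SumImpl<kDGLCUDA, IdType>(data, len);
    default: throw std::runtime_error("unsupported device type");
  }
}
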
@@ -13,7 +13,7 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template<DLDeviceType XPU, typename DType, typename IdType>
+template<DGLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   const DType* array_data = static_cast<DType*>(array->data);
@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) {
   return ret;
 }
-template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
 #ifdef USE_FP16
-template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
 #endif
-template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
-template <DLDeviceType XPU, typename DType>
+template <DGLDeviceType XPU, typename DType>
 DType IndexSelect(NDArray array, int64_t index) {
   auto device = runtime::DeviceAPI::Get(array->ctx);
 #ifdef USE_FP16
@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) {
 #endif
   device->CopyDataFromTo(
       static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
-      sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
-      array->dtype);
+      sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
   return reinterpret_cast<DType&>(ret);
 }
-template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
-template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
-template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
-template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
+template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
+template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
+template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
+template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
 #ifdef USE_FP16
-template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
+template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
 #endif
-template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
-template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
+template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
+template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
 } // namespace impl
 } // namespace aten
...
@@ -26,7 +26,7 @@ struct IsNonZeroIndex {
   const IdType * array_;
 };
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 IdArray NonZero(IdArray array) {
   const auto& ctx = array->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
   return ret.CreateView({num_nonzeros}, ret->dtype, 0);
 }
-template IdArray NonZero<kDLGPU, int32_t>(IdArray);
-template IdArray NonZero<kDLGPU, int64_t>(IdArray);
+template IdArray NonZero<kDGLCUDA, int32_t>(IdArray);
+template IdArray NonZero<kDGLCUDA, int64_t>(IdArray);
 } // namespace impl
 } // namespace aten
...
@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType, typename Op>
+template <DGLDeviceType XPU, typename IdType, typename Op>
 IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
   const int64_t len = lhs->shape[0];
   IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
 template <typename IdType, typename Op>
@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType, typename Op>
+template <DGLDeviceType XPU, typename IdType, typename Op>
 IdArray BinaryElewise(IdArray lhs, IdType rhs) {
   const int64_t len = lhs->shape[0];
   IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType, typename Op>
+template <DGLDeviceType XPU, typename IdType, typename Op>
 IdArray BinaryElewise(IdType lhs, IdArray rhs) {
   const int64_t len = rhs->shape[0];
   IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
 template <typename IdType, typename Op>
 __global__ void _UnaryElewiseKernel(
@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType, typename Op>
+template <DGLDeviceType XPU, typename IdType, typename Op>
 IdArray UnaryElewise(IdArray lhs) {
   const int64_t len = lhs->shape[0];
   IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) {
   return ret;
 }
-template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
-template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
+template IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);
+template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
 ///////////////////////////// Full /////////////////////////////
@@ -210,9 +210,9 @@ __global__ void _FullKernel(
   }
 }
-template <DLDeviceType XPU, typename DType>
-NDArray Full(DType val, int64_t length, DLContext ctx) {
-  NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
+template <DGLDeviceType XPU, typename DType>
+NDArray Full(DType val, int64_t length, DGLContext ctx) {
+  NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
   DType* ret_data = static_cast<DType*>(ret->data);
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(length);
@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
   return ret;
 }
-template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
-template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
+template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
+template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
 #ifdef USE_FP16
-template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
+template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
 #endif
-template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
-template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
+template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
+template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
 ///////////////////////////// Range /////////////////////////////
@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
   }
 }
-template <DLDeviceType XPU, typename IdType>
-IdArray Range(IdType low, IdType high, DLContext ctx) {
+template <DGLDeviceType XPU, typename IdType>
+IdArray Range(IdType low, IdType high, DGLContext ctx) {
   CHECK(high >= low) << "high must be bigger than low";
   const IdType length = high - low;
   IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
   return ret;
 }
-template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
-template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
+template IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);
+template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
 ///////////////////////////// Relabel_ //////////////////////////////
@@ -278,7 +278,7 @@ __global__ void _RelabelKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 IdArray Relabel_(const std::vector<IdArray>& arrays) {
   IdArray all_nodes = Concat(arrays);
   const int64_t total_length = all_nodes->shape[0];
@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
       &num_induced, 0,
       sizeof(num_induced),
       ctx,
-      DGLContext{kDLCPU, 0},
-      DGLType{kDLInt, 64, 1});
+      DGLContext{kDGLCPU, 0},
+      DGLDataType{kDGLInt, 64, 1});
   device->StreamSync(ctx, stream);
   device->FreeWorkspace(ctx, num_induced_device);
@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
   return induced_nodes;
 }
-template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
-template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
+template IdArray Relabel_<kDGLCUDA, int32_t>(const std::vector<IdArray>& arrays);
+template IdArray Relabel_<kDGLCUDA, int64_t>(const std::vector<IdArray>& arrays);
 ///////////////////////////// AsNumBits /////////////////////////////
@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
   }
 }
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 IdArray AsNumBits(IdArray arr, uint8_t bits) {
   const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
-  IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
+  IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);
   const int64_t length = ret.NumElements();
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   int nt = cuda::FindNumThreads(length);
@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
 }
-template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
-template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
+template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
+template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
 } // namespace impl
 } // namespace aten
...
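
The Full and CSRGetData hunks lean on DGLDataTypeTraits<DType>::dtype, which maps a compile-time scalar type to its runtime DGLDataType descriptor. Below is a minimal C++ sketch of such a traits table, assuming DGLDataType keeps DLPack's {code, bits, lanes} layout and that the kDGL* type codes mirror DLPack's; the bodies are illustrative, not DGL's actual header.

#include <cstdint>

enum DGLDataTypeCode : uint8_t { kDGLInt = 0, kDGLUInt = 1, kDGLFloat = 2 };  // assumed codes
struct DGLDataType { uint8_t code; uint8_t bits; uint16_t lanes; };

// Compile-time type -> runtime dtype descriptor, as used by
// NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx) above.
template <typename T> struct DGLDataTypeTraits;
template <> struct DGLDataTypeTraits<int32_t> {
  static constexpr DGLDataType dtype{kDGLInt, 32, 1};
};
template <> struct DGLDataTypeTraits<int64_t> {
  static constexpr DGLDataType dtype{kDGLInt, 64, 1};
};
template <> struct DGLDataTypeTraits<float> {
  static constexpr DGLDataType dtype{kDGLFloat, 32, 1};
};
template <> struct DGLDataTypeTraits<double> {
  static constexpr DGLDataType dtype{kDGLFloat, 64, 1};
};
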
@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value,
   }
 }
-template <DLDeviceType XPU, typename DType, typename IdType>
+template <DGLDeviceType XPU, typename DType, typename IdType>
 void Scatter_(IdArray index, NDArray value, NDArray out) {
   const int64_t len = index->shape[0];
   const IdType* idx = index.Ptr<IdType>();
@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
       idx, val, len, outd);
 }
-template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
 #ifdef USE_FP16
-template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
 #endif
-template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
 #ifdef USE_FP16
-template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
 #endif
-template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
 }; // namespace impl
 }; // namespace aten
...
@@ -13,7 +13,7 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   const auto& ctx = array->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   return std::make_pair(sorted_array, sorted_idx);
 }
-template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
-template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
+template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits);
+template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits);
 } // namespace impl
 } // namespace aten
...
@@ -14,14 +14,14 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 CSRMatrix COOToCSR(COOMatrix coo) {
   LOG(FATAL) << "Unreachable code.";
   return {};
 }
 template <>
-CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
+CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound(
 }
 template <>
-CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
+CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
   const auto& ctx = coo.row->ctx;
   const auto nbits = coo.row->dtype.bits;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
       indptr, coo.col, coo.data, col_sorted);
 }
-template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
-template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
+template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
+template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
 } // namespace impl
 } // namespace aten
...
@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) {
   return bits;
 }
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 void COOSort_(COOMatrix* coo, bool sort_column) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   const int row_bits = _NumberOfBits(coo->num_rows);
@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
   }
 }
-template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
-template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
+template void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);
+template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
 ///////////////////////////// COOIsSorted /////////////////////////////
@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel(
   }
 }
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
   const int64_t nnz = coo.row->shape[0];
   const auto& ctx = coo.row->ctx;
@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
   return {row_sorted, col_sorted};
 }
-template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
-template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
+template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);
+template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);
 } // namespace impl
 } // namespace aten
...
@@ -14,14 +14,14 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 COOMatrix CSRToCOO(CSRMatrix csr) {
   LOG(FATAL) << "Unreachable codes";
   return {};
 }
 template <>
-COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
+COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
@@ -77,7 +77,7 @@ __global__ void _RepeatKernel(
 }
 template <>
-COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
+COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
   const auto& ctx = csr.indptr->ctx;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -99,18 +99,18 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
       true, csr.sorted);
 }
-template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
-template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
+template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
+template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
   LOG(FATAL) << "Unreachable codes";
   return {};
 }
 template <>
-COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
-  COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
+COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
+  COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
   if (aten::IsNullArray(coo.data))
     return coo;
@@ -156,8 +156,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
 }
 template <>
-COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
-  COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
+COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
+  COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
   if (aten::IsNullArray(coo.data))
     return coo;
   const auto& sorted = Sort(coo.data);
@@ -173,8 +173,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
   return coo;
 }
-template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
-template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
+template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
+template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);
 } // namespace impl
 } // namespace aten
...
@@ -17,7 +17,7 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <DLDeviceType XPU, typename IdType, typename DType>
+template <DGLDeviceType XPU, typename IdType, typename DType>
 NDArray CSRGetData(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
   const int64_t rowlen = rows->shape[0];
@@ -38,7 +38,7 @@ NDArray CSRGetData(
   const int nt = cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   if (return_eids)
-    BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
+    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
       "DType does not match row's dtype.";
   // TODO(minjie): use binary search for sorted csr
@@ -53,24 +53,24 @@ NDArray CSRGetData(
 }
 #ifdef USE_FP16
-template NDArray CSRGetData<kDLGPU, int32_t, __half>(
+template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
-template NDArray CSRGetData<kDLGPU, int64_t, __half>(
+template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
 #endif
-template NDArray CSRGetData<kDLGPU, int32_t, float>(
+template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
-template NDArray CSRGetData<kDLGPU, int64_t, float>(
+template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
-template NDArray CSRGetData<kDLGPU, int32_t, double>(
+template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
-template NDArray CSRGetData<kDLGPU, int64_t, double>(
+template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
 // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
-template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
+template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
-template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
+template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
 } // namespace impl
...
@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
 }
 #ifdef USE_FP16
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
 #endif
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
 } // namespace aten
...
@@ -34,7 +34,7 @@ __global__ void _SegmentIsSorted(
   }
 }
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 bool CSRIsSorted(CSRMatrix csr) {
   const auto& ctx = csr.indptr->ctx;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -53,16 +53,16 @@ bool CSRIsSorted(CSRMatrix csr) {
   return ret;
 }
-template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
-template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
+template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
+template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);
-template <DLDeviceType XPU, typename IdType>
+template <DGLDeviceType XPU, typename IdType>
 void CSRSort_(CSRMatrix* csr) {
   LOG(FATAL) << "Unreachable codes";
 }
 template <>
-void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
+void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
   cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -108,7 +108,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
 }
 template <>
-void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
+void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
@@ -147,8 +147,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
   device->FreeWorkspace(ctx, workspace);
 }
-template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
-template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
+template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
+template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);
 } // namespace impl
 } // namespace aten
...
@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
 }
 #ifdef USE_FP16
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
 #endif
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
 } // namespace aten
...
...@@ -13,14 +13,14 @@ using runtime::NDArray; ...@@ -13,14 +13,14 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) { CSRMatrix CSRTranspose(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes"; LOG(FATAL) << "Unreachable codes";
return {}; return {};
} }
template <> template <>
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) { CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed // allocate cusparse handle if needed
...@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) { ...@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
} }
template <> template <>
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) { CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {
return COOToCSR(COOTranspose(CSRToCOO(csr, false))); return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
} }
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr); template CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr); template CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
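Note the asymmetry in `CSRTranspose`: the `int32_t` specialization goes through cuSPARSE, while the `int64_t` one falls back to a COO round trip, presumably because the cuSPARSE transpose path only handles 32-bit indices. The round trip works because transposing a COO matrix is just swapping its coordinate arrays; a minimal CPU sketch with a hypothetical `SimpleCOO` type:

#include <cstdint>
#include <utility>
#include <vector>

struct SimpleCOO {
  int64_t num_rows, num_cols;
  std::vector<int64_t> row, col;
};

SimpleCOO COOTransposeSketch(SimpleCOO m) {
  std::swap(m.num_rows, m.num_cols);
  m.row.swap(m.col);  // every (r, c) entry becomes (c, r)
  return m;
}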
...@@ -105,7 +105,7 @@ IdArray _PerformFilter( ...@@ -105,7 +105,7 @@ IdArray _PerformFilter(
&num_unique, 0, &num_unique, 0,
sizeof(num_unique), sizeof(num_unique),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
test->dtype); test->dtype);
// insert items into set // insert items into set
...@@ -150,13 +150,13 @@ class CudaFilterSet : public Filter { ...@@ -150,13 +150,13 @@ class CudaFilterSet : public Filter {
} // namespace } // namespace
template<DLDeviceType XPU, typename IdType> template<DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) { FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set)); return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
} }
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set); template FilterRef CreateSetFilter<kDGLCUDA, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set); template FilterRef CreateSetFilter<kDGLCUDA, int64_t>(IdArray set);
} // namespace array } // namespace array
} // namespace dgl } // namespace dgl
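`CreateSetFilter` wraps a GPU hash set behind the `FilterRef` interface; the `_PerformFilter` fragment above copies the resulting `num_unique` count back to a CPU-side scalar, which is why the destination context is `DGLContext{kDGLCPU, 0}`. The contract, sketched on the CPU with hypothetical names (the real `CudaFilterSet` builds the set on the device):

#include <cstdint>
#include <unordered_set>
#include <vector>

// Return the positions in `test` whose values are members of the set.
std::vector<int64_t> IncludedIndices(const std::vector<int64_t>& set_vals,
                                     const std::vector<int64_t>& test) {
  std::unordered_set<int64_t> members(set_vals.begin(), set_vals.end());
  std::vector<int64_t> out;
  for (int64_t i = 0; i < static_cast<int64_t>(test.size()); ++i) {
    if (members.count(test[i]) > 0) out.push_back(i);
  }
  return out;
}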
...@@ -47,7 +47,7 @@ __global__ void _DisjointUnionKernel( ...@@ -47,7 +47,7 @@ __global__ void _DisjointUnionKernel(
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) { std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdType n = coos.size(), nbits = coos[0].row->dtype.bits; IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
IdArray n_rows = NewIdArray(n, CPU, nbits); IdArray n_rows = NewIdArray(n, CPU, nbits);
...@@ -71,10 +71,10 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -71,10 +71,10 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(n_elms.CopyTo(coos[0].row->ctx), true)); CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out, void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
int64_t n_arrs, int n_elms, int64_t n_arrs, int n_elms,
DGLContext ctx, DGLType dtype, cudaStream_t stream) { DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
int nt = 256; int nt = 256;
int nb = (n_elms + nt - 1) / nt; int nb = (n_elms + nt - 1) / nt;
...@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out, ...@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->CopyDataFromTo( device->CopyDataFromTo(
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs, arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs,
DGLContext{kDLCPU, 0}, ctx, dtype); DGLContext{kDGLCPU, 0}, ctx, dtype);
CUDA_KERNEL_CALL(_DisjointUnionKernel, CUDA_KERNEL_CALL(_DisjointUnionKernel,
nb, nt, 0, stream, nb, nt, 0, stream,
...@@ -94,7 +94,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out, ...@@ -94,7 +94,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->FreeWorkspace(ctx, arrs_dev); device->FreeWorkspace(ctx, arrs_dev);
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(coos[0].row->ctx); auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);
...@@ -133,17 +133,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -133,17 +133,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
IdType n_elements = 0; IdType n_elements = 0;
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_elm[coos.size()], 0, &n_elements, 0, &prefix_elm[coos.size()], 0, &n_elements, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0}, sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype); coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_src[coos.size()], 0, &src_offset, 0, &prefix_src[coos.size()], 0, &src_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0}, sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype); coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_dst[coos.size()], 0, &dst_offset, 0, &prefix_dst[coos.size()], 0, &dst_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0}, sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype); coos[0].row->dtype);
// Union src array // Union src array
...@@ -176,8 +176,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -176,8 +176,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted); col_sorted);
} }
template COOMatrix DisjointUnionCoo<kDLGPU, int32_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLGPU, int64_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
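`DisjointUnionCoo` relabels each component's vertex ids by the exclusive prefix sums of the per-component row/column counts (the three `CopyDataFromTo` calls above fetch the totals back to the CPU), then `_Merge` concatenates the shifted arrays on the GPU. The same computation on the CPU, as a sketch with a hypothetical `SimpleCOO` type:

#include <cstdint>
#include <vector>

struct SimpleCOO {
  int64_t num_rows, num_cols;
  std::vector<int64_t> row, col;
};

SimpleCOO DisjointUnionSketch(const std::vector<SimpleCOO>& coos) {
  SimpleCOO out{0, 0, {}, {}};
  for (const SimpleCOO& c : coos) {
    for (int64_t r : c.row) out.row.push_back(r + out.num_rows);
    for (int64_t v : c.col) out.col.push_back(v + out.num_cols);
    out.num_rows += c.num_rows;  // running (prefix-sum) offsets
    out.num_cols += c.num_cols;
  }
  return out;
}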
...@@ -394,74 +394,74 @@ void GatherMMScatter(const NDArray A, ...@@ -394,74 +394,74 @@ void GatherMMScatter(const NDArray A,
} }
template void GatherMM<kDLGPU, int32_t, 16>( template void GatherMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 16>( template void GatherMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 32>( template void GatherMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 32>( template void GatherMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 64>( template void GatherMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 64>( template void GatherMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLGPU, int32_t, 16>( template void GatherMMScatter<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 16>( template void GatherMMScatter<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 32>( template void GatherMMScatter<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 32>( template void GatherMMScatter<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 64>( template void GatherMMScatter<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 64>( template void GatherMMScatter<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLGPU, int32_t, 16>( template void SegmentMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 16>( template void SegmentMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 32>( template void SegmentMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 32>( template void SegmentMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 64>( template void SegmentMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 64>( template void SegmentMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLGPU, int32_t, 16>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 16>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 32>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 32>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 64>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 64>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten } // namespace aten
......
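In these `GatherMM`/`SegmentMM` instantiations the third template argument (16/32/64) appears to select the floating-point width of the data, alongside the usual device and index-type parameters. For reference, a CPU sketch of the `SegmentMM` contract under assumed shapes: rows of A are split into segments by `seglen`, and segment i is multiplied by the i-th k-by-n slice of B.

#include <cstddef>
#include <cstdint>
#include <vector>

// A: (total_rows x k) row-major, B: (num_segments x k x n) row-major,
// C: (total_rows x n) row-major, total_rows == sum(seglen).
void SegmentMMSketch(const std::vector<double>& A, const std::vector<double>& B,
                     std::vector<double>& C, const std::vector<int64_t>& seglen,
                     int64_t k, int64_t n) {
  int64_t row = 0;
  for (std::size_t s = 0; s < seglen.size(); ++s) {
    const double* Bs = B.data() + s * k * n;  // weight slice for segment s
    for (int64_t r = 0; r < seglen[s]; ++r, ++row) {
      for (int64_t j = 0; j < n; ++j) {
        double acc = 0.0;
        for (int64_t t = 0; t < k; ++t) acc += A[row * k + t] * Bs[t * n + j];
        C[row * n + j] = acc;
      }
    }
  }
}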
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
} \ } \
} else { \ } else { \
constexpr bool UseBcast = true; \ constexpr bool UseBcast = true; \
const DLContext ctx = (CTX); \ const DGLContext ctx = (CTX); \
const auto device = runtime::DeviceAPI::Get(ctx); \ const auto device = runtime::DeviceAPI::Get(ctx); \
(LHS_OFF) = static_cast<int64_t*>( \ (LHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \ device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \
......
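The macro fragment above stages per-element offset tables into device workspace when operands broadcast. What those tables plausibly hold (this is an inference from the names, not the actual DGL broadcast logic): for each flattened output position, the operand position to read, with size-1 dimensions contributing stride 0. A hypothetical CPU sketch, assuming both shapes have the same rank:

#include <cstdint>
#include <vector>

std::vector<int64_t> BroadcastOffsets(const std::vector<int64_t>& out_shape,
                                      const std::vector<int64_t>& lhs_shape) {
  int64_t out_size = 1;
  for (int64_t d : out_shape) out_size *= d;
  std::vector<int64_t> off(out_size);
  for (int64_t j = 0; j < out_size; ++j) {
    int64_t rem = j, pos = 0, stride = 1;
    for (int64_t d = static_cast<int64_t>(out_shape.size()) - 1; d >= 0; --d) {
      const int64_t idx = rem % out_shape[d];
      rem /= out_shape[d];
      if (lhs_shape[d] != 1) pos += idx * stride;  // broadcast dims contribute 0
      stride *= lhs_shape[d];
    }
    off[j] = pos;
  }
  return off;
}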
...@@ -93,7 +93,7 @@ struct IsNotMinusOne { ...@@ -93,7 +93,7 @@ struct IsNotMinusOne {
template <typename IdType> template <typename IdType>
void SortOrderedPairs( void SortOrderedPairs(
runtime::DeviceAPI* device, runtime::DeviceAPI* device,
DLContext ctx, DGLContext ctx,
IdType* major, IdType* major,
IdType* minor, IdType* minor,
IdType* tmp_major, IdType* tmp_major,
...@@ -128,7 +128,7 @@ void SortOrderedPairs( ...@@ -128,7 +128,7 @@ void SortOrderedPairs(
}; // namespace }; // namespace
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr, const CSRMatrix& csr,
int64_t num_samples, int64_t num_samples,
...@@ -211,9 +211,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -211,9 +211,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return result; return result;
} }
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>( template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double); const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>( template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double); const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl }; // namespace impl
......
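`CSRGlobalUniformNegativeSampling` draws candidate (row, col) pairs uniformly and rejects those already present as edges; the `int` argument in the instantiations above is plausibly a bound on retries. A CPU sketch of the rejection loop, with hypothetical names and an unbounded loop where the real kernel would cap the number of attempts:

#include <cstdint>
#include <random>
#include <utility>
#include <vector>

// indptr/indices form a CSR adjacency; returns num_samples non-edges.
std::vector<std::pair<int64_t, int64_t>> NegativeSampleSketch(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    int64_t num_cols, int64_t num_samples, std::mt19937_64& rng) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  std::uniform_int_distribution<int64_t> row_d(0, num_rows - 1);
  std::uniform_int_distribution<int64_t> col_d(0, num_cols - 1);
  std::vector<std::pair<int64_t, int64_t>> result;
  while (static_cast<int64_t>(result.size()) < num_samples) {
    const int64_t u = row_d(rng), v = col_d(rng);
    bool exists = false;
    for (int64_t e = indptr[u]; e < indptr[u + 1]; ++e) {
      if (indices[e] == v) { exists = true; break; }
    }
    if (!exists) result.emplace_back(u, v);  // keep only non-edges
  }
  return result;
}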