"tensoradapter/vscode:/vscode.git/clone" did not exist on "fc6f0b9efc5cbb190685a9684985b9223c707574"
Unverified commit cded5b80, authored by Xin Yao and committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
...@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals( ...@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals(
total_len += traces[i].size(); total_len += traces[i].size();
} }
IdArray ret = IdArray::Empty({total_len}, IdArray ret = IdArray::Empty({total_len},
DLDataType{kDLInt, sizeof(DType) * 8, 1}, DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) { for (size_t j = 0; j < traces.size(); ++j) {
...@@ -79,7 +79,7 @@ IdArray ComputeMergedSections( ...@@ -79,7 +79,7 @@ IdArray ComputeMergedSections(
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
} }
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0; int64_t sec_len = 0;
...@@ -96,7 +96,7 @@ IdArray ComputeMergedSections( ...@@ -96,7 +96,7 @@ IdArray ComputeMergedSections(
} // namespace } // namespace
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) { Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids; std::vector<IdType> ids;
std::vector<int64_t> sections; std::vector<int64_t> sections;
...@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
return front; return front;
} }
template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray); template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray); template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) { Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids; std::vector<IdType> ids;
std::vector<int64_t> sections; std::vector<int64_t> sections;
...@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
return front; return front;
} }
template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray); template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray); template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) { Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
std::vector<IdType> ids; std::vector<IdType> ids;
std::vector<int64_t> sections; std::vector<int64_t> sections;
...@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) { ...@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
return front; return front;
} }
template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&); template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&); template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) { Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const IdType* src_data = static_cast<IdType*>(source->data); const IdType* src_data = static_cast<IdType*>(source->data);
...@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) { ...@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
return front; return front;
} }
template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray); template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray); template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source, IdArray source,
const bool has_reverse_edge, const bool has_reverse_edge,
...@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, ...@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
return front; return front;
} }
template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&, template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&,
IdArray, IdArray,
const bool, const bool,
const bool, const bool,
const bool); const bool);
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&, template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
IdArray, IdArray,
const bool, const bool,
const bool, const bool,
......
...@@ -13,7 +13,7 @@ using runtime::NDArray; ...@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) { IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements(); const int64_t len = array.NumElements();
if (len == 0) if (len == 0)
...@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) { ...@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return ret; return ret;
} }
template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool); template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool); template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -13,7 +13,7 @@ using runtime::NDArray; ...@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType> template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) { NDArray IndexSelect(NDArray array, IdArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data); const DType* array_data = static_cast<DType*>(array->data);
...@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) { ...@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret; return ret;
} }
template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16 #ifdef USE_FP16
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif #endif
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) { DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx); auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16 #ifdef USE_FP16
...@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) { ...@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) {
#endif #endif
device->CopyDataFromTo( device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0, static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
sizeof(DType), array->ctx, DLContext{kDLCPU, 0}, sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
array->dtype);
return reinterpret_cast<DType&>(ret); return reinterpret_cast<DType&>(ret);
} }
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index); template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index); template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index); template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index); template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16 #ifdef USE_FP16
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index); template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#endif #endif
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index); template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index); template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -26,7 +26,7 @@ struct IsNonZeroIndex { ...@@ -26,7 +26,7 @@ struct IsNonZeroIndex {
const IdType * array_; const IdType * array_;
}; };
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) { IdArray NonZero(IdArray array) {
const auto& ctx = array->ctx; const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
...@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) { ...@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
return ret.CreateView({num_nonzeros}, ret->dtype, 0); return ret.CreateView({num_nonzeros}, ret->dtype, 0);
} }
template IdArray NonZero<kDLGPU, int32_t>(IdArray); template IdArray NonZero<kDGLCUDA, int32_t>(IdArray);
template IdArray NonZero<kDLGPU, int64_t>(IdArray); template IdArray NonZero<kDGLCUDA, int64_t>(IdArray);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel( ...@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel(
} }
} }
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) { IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
const int64_t len = lhs->shape[0]; const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
...@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) { ...@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <typename IdType, typename Op> template <typename IdType, typename Op>
...@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel( ...@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel(
} }
} }
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) { IdArray BinaryElewise(IdArray lhs, IdType rhs) {
const int64_t len = lhs->shape[0]; const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
...@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) { ...@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
...@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel( ...@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel(
} }
} }
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) { IdArray BinaryElewise(IdType lhs, IdArray rhs) {
const int64_t len = rhs->shape[0]; const int64_t len = rhs->shape[0];
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits); IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
...@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) { ...@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <typename IdType, typename Op> template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel( __global__ void _UnaryElewiseKernel(
...@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel( ...@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel(
} }
} }
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) { IdArray UnaryElewise(IdArray lhs) {
const int64_t len = lhs->shape[0]; const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
...@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) { ...@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) {
return ret; return ret;
} }
template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs); template IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs); template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full ///////////////////////////// ///////////////////////////// Full /////////////////////////////
...@@ -210,9 +210,9 @@ __global__ void _FullKernel( ...@@ -210,9 +210,9 @@ __global__ void _FullKernel(
} }
} }
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) { NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx); NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length); int nt = cuda::FindNumThreads(length);
...@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) { ...@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret; return ret;
} }
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx); template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx); template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
#ifdef USE_FP16 #ifdef USE_FP16
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx); template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
#endif #endif
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx); template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx); template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range ///////////////////////////// ///////////////////////////// Range /////////////////////////////
...@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) { ...@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) { IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low"; CHECK(high >= low) << "high must be bigger than low";
const IdType length = high - low; const IdType length = high - low;
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8); IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
...@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) { ...@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret; return ret;
} }
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext); template IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext); template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ ////////////////////////////// ///////////////////////////// Relabel_ //////////////////////////////
...@@ -278,7 +278,7 @@ __global__ void _RelabelKernel( ...@@ -278,7 +278,7 @@ __global__ void _RelabelKernel(
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) { IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray all_nodes = Concat(arrays); IdArray all_nodes = Concat(arrays);
const int64_t total_length = all_nodes->shape[0]; const int64_t total_length = all_nodes->shape[0];
...@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
&num_induced, 0, &num_induced, 0,
sizeof(num_induced), sizeof(num_induced),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLType{kDLInt, 64, 1}); DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream); device->StreamSync(ctx, stream);
device->FreeWorkspace(ctx, num_induced_device); device->FreeWorkspace(ctx, num_induced_device);
...@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return induced_nodes; return induced_nodes;
} }
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays); template IdArray Relabel_<kDGLCUDA, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays); template IdArray Relabel_<kDGLCUDA, int64_t>(const std::vector<IdArray>& arrays);
///////////////////////////// AsNumBits ///////////////////////////// ///////////////////////////// AsNumBits /////////////////////////////
...@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) { ...@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) { IdArray AsNumBits(IdArray arr, uint8_t bits) {
const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim); const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx); IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);
const int64_t length = ret.NumElements(); const int64_t length = ret.NumElements();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length); int nt = cuda::FindNumThreads(length);
...@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) { ...@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
} }
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits); template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits); template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value, ...@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value,
} }
} }
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) { void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0]; const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>(); const IdType* idx = index.Ptr<IdType>();
...@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) { ...@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
idx, val, len, outd); idx, val, len, outd);
} }
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16 #ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#endif #endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16 #ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#endif #endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -13,7 +13,7 @@ using runtime::NDArray; ...@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) { std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
const auto& ctx = array->ctx; const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
...@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) { ...@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
return std::make_pair(sorted_array, sorted_idx); return std::make_pair(sorted_array, sorted_idx);
} }
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -14,14 +14,14 @@ using runtime::NDArray; ...@@ -14,14 +14,14 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) { CSRMatrix COOToCSR(COOMatrix coo) {
LOG(FATAL) << "Unreachable code."; LOG(FATAL) << "Unreachable code.";
return {}; return {};
} }
template <> template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) { CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed // allocate cusparse handle if needed
...@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound( ...@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound(
} }
template <> template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) { CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
const auto& ctx = coo.row->ctx; const auto& ctx = coo.row->ctx;
const auto nbits = coo.row->dtype.bits; const auto nbits = coo.row->dtype.bits;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) { ...@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
indptr, coo.col, coo.data, col_sorted); indptr, coo.col, coo.data, col_sorted);
} }
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) { ...@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) {
return bits; return bits;
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) { void COOSort_(COOMatrix* coo, bool sort_column) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const int row_bits = _NumberOfBits(coo->num_rows); const int row_bits = _NumberOfBits(coo->num_rows);
...@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
} }
} }
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column); template void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column); template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
///////////////////////////// COOIsSorted ///////////////////////////// ///////////////////////////// COOIsSorted /////////////////////////////
...@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel( ...@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel(
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) { std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const auto& ctx = coo.row->ctx; const auto& ctx = coo.row->ctx;
...@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted}; return {row_sorted, col_sorted};
} }
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo); template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo); template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
This diff is collapsed.
This diff is collapsed.
...@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
} }
#ifdef USE_FP16 #ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif #endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
} // namespace aten } // namespace aten
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
} \ } \
} else { \ } else { \
constexpr bool UseBcast = true; \ constexpr bool UseBcast = true; \
const DLContext ctx = (CTX); \ const DGLContext ctx = (CTX); \
const auto device = runtime::DeviceAPI::Get(ctx); \ const auto device = runtime::DeviceAPI::Get(ctx); \
(LHS_OFF) = static_cast<int64_t*>( \ (LHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \ device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment