"src/turbomind/models/vscode:/vscode.git/clone" did not exist on "2f80c556fbdb2e37e93446913101e08ce8bfcc4c"
Unverified Commit cded5b80 authored by Xin Yao, committed by GitHub

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
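The diff below is mechanical: every occurrence of a DLPack name in the core library is replaced by a DGL-owned twin (DLContext -> DGLContext, DLDataType/DGLType -> DGLDataType, DLDeviceType -> DGLDeviceType, kDLCPU -> kDGLCPU, kDLGPU -> kDGLCUDA). A minimal sketch of the idea, assuming the DGL-side definitions simply mirror DLPack v0.7's published numbering (the struct layouts here are illustrative, not the literal DGL headers):

// Hypothetical sketch: give the core library its own device/type definitions
// so <dlpack/dlpack.h> is only included at the conversion boundary.
#include <cstdint>

typedef enum {
  kDGLCPU = 1,   // mirrors DLPack's kDLCPU
  kDGLCUDA = 2,  // mirrors DLPack v0.7's kDLCUDA (formerly kDLGPU)
} DGLDeviceType;

typedef struct {
  DGLDeviceType device_type;
  int32_t device_id;
} DGLContext;  // drop-in replacement for the former DLContext usage

typedef struct {
  uint8_t code;    // integer/unsigned/float type code
  uint8_t bits;    // e.g. 32 or 64
  uint16_t lanes;  // 1 for ordinary arrays
} DGLDataType;  // unifies the former DGLType / DLDataType pair

Because the enum values match DLPack's by construction, converting between the two families at the FFI boundary is a value-preserving cast rather than a lookup table.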
......@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals(
total_len += traces[i].size();
}
IdArray ret = IdArray::Empty({total_len},
DLDataType{kDLInt, sizeof(DType) * 8, 1},
DLContext{kDLCPU, 0});
DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
DGLContext{kDGLCPU, 0});
DType* ret_data = static_cast<DType*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) {
......@@ -79,7 +79,7 @@ IdArray ComputeMergedSections(
const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen);
}
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0;
......@@ -96,7 +96,7 @@ IdArray ComputeMergedSections(
} // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
......@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
......@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
......@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
return front;
}
template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
const int64_t len = source->shape[0];
const IdType* src_data = static_cast<IdType*>(source->data);
......@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
......@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
return front;
}
template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&,
IdArray,
const bool,
const bool,
const bool);
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
IdArray,
const bool,
const bool,
......
......@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
......@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return ret;
}
template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
} // namespace impl
} // namespace aten
......
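The explicit kDGLCUDA instantiations above exist because the aten layer dispatches on the array's device type at runtime and links against exactly these symbols. A minimal caller-side sketch of such a switch (the function name and shape are assumptions; DGL generates its dispatch via macros, and the real code also branches on the 32-/64-bit integer width):

// Hypothetical dispatch sketch, illustrative only.
IdArray CumSumDispatch(IdArray array, bool prepend_zero) {
  switch (array->ctx.device_type) {
    case kDGLCPU:
      return impl::CumSum<kDGLCPU, int64_t>(array, prepend_zero);
    case kDGLCUDA:
      return impl::CumSum<kDGLCUDA, int64_t>(array, prepend_zero);
    default:
      LOG(FATAL) << "Unsupported device type";
      return {};
  }
}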
......@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
......@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16
......@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) {
#endif
device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
array->dtype);
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return reinterpret_cast<DType&>(ret);
}
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#endif
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
} // namespace impl
} // namespace aten
......
......@@ -26,7 +26,7 @@ struct IsNonZeroIndex {
const IdType * array_;
};
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
......@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
}
template IdArray NonZero<kDLGPU, int32_t>(IdArray);
template IdArray NonZero<kDLGPU, int64_t>(IdArray);
template IdArray NonZero<kDGLCUDA, int32_t>(IdArray);
template IdArray NonZero<kDGLCUDA, int64_t>(IdArray);
} // namespace impl
} // namespace aten
......
......@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
......@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <typename IdType, typename Op>
......@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
......@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
......@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
const int64_t len = rhs->shape[0];
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
......@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel(
......@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
......@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) {
return ret;
}
template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
......@@ -210,9 +210,9 @@ __global__ void _FullKernel(
}
}
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
......@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret;
}
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
#endif
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
......@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
}
}
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low";
const IdType length = high - low;
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
......@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret;
}
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
template IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ //////////////////////////////
......@@ -278,7 +278,7 @@ __global__ void _RelabelKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray all_nodes = Concat(arrays);
const int64_t total_length = all_nodes->shape[0];
......@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
&num_induced, 0,
sizeof(num_induced),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, 64, 1});
DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream);
device->FreeWorkspace(ctx, num_induced_device);
......@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return induced_nodes;
}
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int64_t>(const std::vector<IdArray>& arrays);
///////////////////////////// AsNumBits /////////////////////////////
......@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);
const int64_t length = ret.NumElements();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
......@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
}
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
} // namespace impl
} // namespace aten
......
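Full above relies on DGLDataTypeTraits<DType>::dtype to map a compile-time C++ scalar type to a runtime DGLDataType descriptor. A hedged sketch of what such a traits table can look like, assuming the DGLDataType struct sketched after the commit message and DLPack's usual type codes (0 = int, 2 = float); the real header covers more types, including __half under USE_FP16:

// Hypothetical traits sketch; stand-in struct repeated for self-containment.
#include <cstdint>
struct DGLDataType { uint8_t code; uint8_t bits; uint16_t lanes; };

template <typename T> struct DGLDataTypeTraitsSketch;
template <> struct DGLDataTypeTraitsSketch<int32_t> {
  static constexpr DGLDataType dtype{0, 32, 1};   // 32-bit int
};
template <> struct DGLDataTypeTraitsSketch<int64_t> {
  static constexpr DGLDataType dtype{0, 64, 1};   // 64-bit int
};
template <> struct DGLDataTypeTraitsSketch<float> {
  static constexpr DGLDataType dtype{2, 32, 1};   // 32-bit float
};
template <> struct DGLDataTypeTraitsSketch<double> {
  static constexpr DGLDataType dtype{2, 64, 1};   // 64-bit float
};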
......@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value,
}
}
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
......@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
idx, val, len, outd);
}
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten
......
......@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
......@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
return std::make_pair(sorted_array, sorted_idx);
}
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits);
} // namespace impl
} // namespace aten
......
......@@ -14,14 +14,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
LOG(FATAL) << "Unreachable code.";
return {};
}
template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound(
}
template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
const auto& ctx = coo.row->ctx;
const auto nbits = coo.row->dtype.bits;
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
indptr, coo.col, coo.data, col_sorted);
}
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
......
......@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) {
return bits;
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int row_bits = _NumberOfBits(coo->num_rows);
......@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
}
}
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
///////////////////////////// COOIsSorted /////////////////////////////
......@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
const auto& ctx = coo.row->ctx;
......@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted};
}
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
......
......@@ -14,14 +14,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -77,7 +77,7 @@ __global__ void _RepeatKernel(
}
template <>
COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -99,18 +99,18 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
......@@ -156,8 +156,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
const auto& sorted = Sort(coo.data);
......@@ -173,8 +173,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
return coo;
}
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
......
......@@ -17,7 +17,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
......@@ -38,7 +38,7 @@ NDArray CSRGetData(
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
// TODO(minjie): use binary search for sorted csr
......@@ -53,24 +53,24 @@ NDArray CSRGetData(
}
#ifdef USE_FP16
template NDArray CSRGetData<kDLGPU, int32_t, __half>(
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDLGPU, int64_t, __half>(
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif
template NDArray CSRGetData<kDLGPU, int32_t, float>(
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>(
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int32_t, double>(
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLGPU, int64_t, double>(
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl
......
......@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
} // namespace aten
......
......@@ -34,7 +34,7 @@ __global__ void _SegmentIsSorted(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -53,16 +53,16 @@ bool CSRIsSorted(CSRMatrix csr) {
return ret;
}
template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
LOG(FATAL) << "Unreachable codes";
}
template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -108,7 +108,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
}
template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
......@@ -147,8 +147,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
device->FreeWorkspace(ctx, workspace);
}
template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);
} // namespace impl
} // namespace aten
......
......@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
} // namespace aten
......
......@@ -13,14 +13,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
}
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
......
......@@ -105,7 +105,7 @@ IdArray _PerformFilter(
&num_unique, 0,
sizeof(num_unique),
ctx,
DGLContext{kDLCPU, 0},
DGLContext{kDGLCPU, 0},
test->dtype);
// insert items into set
......@@ -150,13 +150,13 @@ class CudaFilterSet : public Filter {
} // namespace
template<DLDeviceType XPU, typename IdType>
template<DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
}
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
template FilterRef CreateSetFilter<kDGLCUDA, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDGLCUDA, int64_t>(IdArray set);
} // namespace array
} // namespace dgl
......@@ -47,7 +47,7 @@ __global__ void _DisjointUnionKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
IdArray n_rows = NewIdArray(n, CPU, nbits);
......@@ -71,10 +71,10 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
int64_t n_arrs, int n_elms,
DGLContext ctx, DGLType dtype, cudaStream_t stream) {
DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx);
int nt = 256;
int nb = (n_elms + nt - 1) / nt;
......@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->CopyDataFromTo(
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs,
DGLContext{kDLCPU, 0}, ctx, dtype);
DGLContext{kDGLCPU, 0}, ctx, dtype);
CUDA_KERNEL_CALL(_DisjointUnionKernel,
nb, nt, 0, stream,
......@@ -94,7 +94,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->FreeWorkspace(ctx, arrs_dev);
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);
......@@ -133,17 +133,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
IdType n_elements = 0;
device->CopyDataFromTo(
&prefix_elm[coos.size()], 0, &n_elements, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
device->CopyDataFromTo(
&prefix_src[coos.size()], 0, &src_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
device->CopyDataFromTo(
&prefix_dst[coos.size()], 0, &dst_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
// Union src array
......@@ -176,8 +176,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted);
}
template COOMatrix DisjointUnionCoo<kDLGPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLGPU, int64_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl
} // namespace aten
......
......@@ -394,74 +394,74 @@ void GatherMMScatter(const NDArray A,
}
template void GatherMM<kDLGPU, int32_t, 16>(
template void GatherMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 16>(
template void GatherMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 32>(
template void GatherMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 32>(
template void GatherMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 64>(
template void GatherMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 64>(
template void GatherMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLGPU, int32_t, 16>(
template void GatherMMScatter<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 16>(
template void GatherMMScatter<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 32>(
template void GatherMMScatter<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 32>(
template void GatherMMScatter<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 64>(
template void GatherMMScatter<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 64>(
template void GatherMMScatter<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLGPU, int32_t, 16>(
template void SegmentMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 16>(
template void SegmentMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 32>(
template void SegmentMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 32>(
template void SegmentMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 64>(
template void SegmentMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 64>(
template void SegmentMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLGPU, int32_t, 16>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 16>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 32>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 32>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 64>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 64>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten
......
......@@ -26,7 +26,7 @@
} \
} else { \
constexpr bool UseBcast = true; \
const DLContext ctx = (CTX); \
const DGLContext ctx = (CTX); \
const auto device = runtime::DeviceAPI::Get(ctx); \
(LHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \
......
......@@ -93,7 +93,7 @@ struct IsNotMinusOne {
template <typename IdType>
void SortOrderedPairs(
runtime::DeviceAPI* device,
DLContext ctx,
DGLContext ctx,
IdType* major,
IdType* minor,
IdType* tmp_major,
......@@ -128,7 +128,7 @@ void SortOrderedPairs(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
......@@ -211,9 +211,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return result;
}
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl
......
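With the renames in place, DLPack types survive only in the dedicated conversion module (renamed from dl_converter to dlpack_convert in this commit). A sketch of what that boundary can look like, assuming DLPack v0.7 where DLContext itself was renamed to DLDevice (the helper names here are assumptions, not DGL's actual API):

// Hypothetical dlpack_convert boundary: DLPack headers are included only
// here, and the enum values match by construction, so conversion is a
// field-by-field cast.
#include <dlpack/dlpack.h>

DGLContext ToDGLContext(DLDevice device) {
  return DGLContext{static_cast<DGLDeviceType>(device.device_type),
                    device.device_id};
}

DGLDataType ToDGLDataType(DLDataType dtype) {
  return DGLDataType{dtype.code, dtype.bits, dtype.lanes};
}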