Commit aaaecbc9 authored by lisj's avatar lisj
Browse files

处理kDLGPU为kDLROCM

parent c454d419
......@@ -136,7 +136,7 @@ struct COOMatrix {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDLROCM: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
......
......@@ -129,7 +129,7 @@ struct CSRMatrix {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDLROCM: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
......
......@@ -46,8 +46,8 @@
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
constexpr auto XPU = kDLGPU; \
} else if ((val) == kDLROCM) { \
constexpr auto XPU = kDLROCM; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
......
......@@ -173,7 +173,7 @@ class NDArray {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDLROCM: invalid, will throw an error.
*/
inline void PinMemory_();
/*!
......@@ -303,7 +303,7 @@ class NDArray {
* Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDLROCM: invalid, will throw an error.
*/
DGL_DLL static void PinContainer(Container* ptr);
......@@ -600,7 +600,7 @@ inline const char* TypeCode2Str(int type_code) {
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
switch (device_type) {
case kDLCPU: return "cpu";
case kDLGPU: return "cuda";
case kDLROCM: return "cuda";
case kDLCPUPinned: return "cpu_pinned";
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
......
......@@ -89,12 +89,18 @@ def device_id(ctx):
else:
return ctx.index
__devtype_th_map = {
1: "cpu",
2: "cuda", # cuda device
10: "cuda" # rocm device
}
def to_backend_ctx(dglctx):
dev_type = dglctx.device_type
if dev_type == 1:
return th.device('cpu')
elif dev_type == 2:
return th.device('cuda', dglctx.device_id)
if dev_type in __devtype_th_map:
th_type = __devtype_th_map[dev_type]
return th.device(th_type, dglctx.device_id)
else:
raise ValueError('Unsupported DGL device context:', dglctx)
......
......@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return ret;
}
template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
template IdArray CumSum<kDLROCM, int32_t>(IdArray, bool);
template IdArray CumSum<kDLROCM, int64_t>(IdArray, bool);
} // namespace impl
} // namespace aten
......
......@@ -51,18 +51,18 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, __half, int64_t>(NDArray, IdArray);
#endif
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLROCM, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
......@@ -84,15 +84,15 @@ DType IndexSelect(NDArray array, int64_t index) {
return reinterpret_cast<DType&>(ret);
}
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
template int32_t IndexSelect<kDLROCM, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLROCM, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLROCM, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLROCM, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
template __half IndexSelect<kDLROCM, __half>(NDArray array, int64_t index);
#endif
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
template float IndexSelect<kDLROCM, float>(NDArray array, int64_t index);
template double IndexSelect<kDLROCM, double>(NDArray array, int64_t index);
} // namespace impl
} // namespace aten
......
......@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
}
template IdArray NonZero<kDLGPU, int32_t>(IdArray);
template IdArray NonZero<kDLGPU, int64_t>(IdArray);
template IdArray NonZero<kDLROCM, int32_t>(IdArray);
template IdArray NonZero<kDLROCM, int64_t>(IdArray);
} // namespace impl
} // namespace aten
......
......@@ -45,28 +45,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <typename IdType, typename Op>
......@@ -95,28 +95,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
......@@ -146,28 +146,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel(
......@@ -195,8 +195,8 @@ IdArray UnaryElewise(IdArray lhs) {
return ret;
}
template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLROCM, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLROCM, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
......@@ -223,13 +223,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret;
}
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLROCM, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLROCM, int64_t>(int64_t val, int64_t length, DLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
template IdArray Full<kDLROCM, __half>(__half val, int64_t length, DLContext ctx);
#endif
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
template IdArray Full<kDLROCM, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLROCM, double>(double val, int64_t length, DLContext ctx);
///////////////////////////// Range /////////////////////////////
......@@ -261,8 +261,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret;
}
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
template IdArray Range<kDLROCM, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLROCM, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// Relabel_ //////////////////////////////
......@@ -339,8 +339,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return induced_nodes;
}
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLROCM, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLROCM, int64_t>(const std::vector<IdArray>& arrays);
///////////////////////////// AsNumBits /////////////////////////////
......@@ -375,8 +375,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
}
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLROCM, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLROCM, int64_t>(IdArray arr, uint8_t bits);
} // namespace impl
} // namespace aten
......
......@@ -38,20 +38,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
idx, val, len, outd);
}
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLROCM, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten
......
......@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
return std::make_pair(sorted_array, sorted_idx);
}
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLROCM, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLROCM, int64_t>(IdArray, int num_bits);
} // namespace impl
} // namespace aten
......
......@@ -22,7 +22,7 @@ CSRMatrix COOToCSR(COOMatrix coo) {
}
template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -101,7 +101,7 @@ __global__ void _SortedSearchKernelUpperBound(
}
template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo) {
const auto& ctx = coo.row->ctx;
const auto nbits = coo.row->dtype.bits;
hipStream_t stream = runtime::getCurrentCUDAStream();
......@@ -134,8 +134,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
indptr, coo.col, coo.data, col_sorted);
}
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
......
......@@ -132,8 +132,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
}
}
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLROCM, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLROCM, int64_t>(COOMatrix* coo, bool sort_column);
///////////////////////////// COOIsSorted /////////////////////////////
......@@ -181,8 +181,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted};
}
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLROCM, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLROCM, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
......
......@@ -22,7 +22,7 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
}
template <>
COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDLROCM, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -78,7 +78,7 @@ __global__ void _RepeatKernel(
}
template <>
COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDLROCM, int64_t>(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
hipStream_t stream = runtime::getCurrentCUDAStream();
......@@ -100,8 +100,8 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLROCM, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLROCM, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
......@@ -110,8 +110,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDLROCM, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLROCM, int32_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
......@@ -157,8 +157,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDLROCM, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLROCM, int64_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
const auto& sorted = Sort(coo.data);
......@@ -174,8 +174,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
return coo;
}
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLROCM, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLROCM, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
......
......@@ -53,24 +53,24 @@ NDArray CSRGetData(
}
#ifdef USE_FP16
template NDArray CSRGetData<kDLGPU, int32_t, __half>(
template NDArray CSRGetData<kDLROCM, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDLGPU, int64_t, __half>(
template NDArray CSRGetData<kDLROCM, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif
template NDArray CSRGetData<kDLGPU, int32_t, float>(
template NDArray CSRGetData<kDLROCM, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>(
template NDArray CSRGetData<kDLROCM, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int32_t, double>(
template NDArray CSRGetData<kDLROCM, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLGPU, int64_t, double>(
template NDArray CSRGetData<kDLROCM, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
template NDArray CSRGetData<kDLROCM, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
template NDArray CSRGetData<kDLROCM, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl
......
......@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
} // namespace aten
......
......@@ -54,8 +54,8 @@ bool CSRIsSorted(CSRMatrix csr) {
return ret;
}
template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLROCM, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLROCM, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
......@@ -63,7 +63,7 @@ void CSRSort_(CSRMatrix* csr) {
}
template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
hipStream_t stream = runtime::getCurrentCUDAStream();
......@@ -109,7 +109,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
}
template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr) {
hipStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
......@@ -148,8 +148,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
device->FreeWorkspace(ctx, workspace);
}
template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr);
} // namespace impl
} // namespace aten
......
......@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
} // namespace aten
......
......@@ -20,7 +20,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDLROCM, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
......@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDLROCM, int64_t>(CSRMatrix csr) {
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
}
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLROCM, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLROCM, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten
......
......@@ -156,8 +156,8 @@ FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
}
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
template FilterRef CreateSetFilter<kDLROCM, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLROCM, int64_t>(IdArray set);
} // namespace array
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment