Commit aaaecbc9 authored by lisj

Replace kDLGPU with kDLROCM

parent c454d419
@@ -136,7 +136,7 @@ struct COOMatrix {
  * \note This is an in-place method. Behavior depends on the current context,
  * kDLCPU: will be pinned;
  * IsPinned: directly return;
- * kDLGPU: invalid, will throw an error.
+ * kDLROCM: invalid, will throw an error.
  * The context check is deferred to pinning the NDArray.
  */
 inline void PinMemory_() {
...
@@ -129,7 +129,7 @@ struct CSRMatrix {
  * \note This is an in-place method. Behavior depends on the current context,
  * kDLCPU: will be pinned;
  * IsPinned: directly return;
- * kDLGPU: invalid, will throw an error.
+ * kDLROCM: invalid, will throw an error.
  * The context check is deferred to pinning the NDArray.
  */
 inline void PinMemory_() {
...
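The updated comments keep the same pinning contract, with kDLROCM taking the role previously played by kDLGPU: a CPU array gets pinned, an already-pinned array returns immediately, and an array already resident on the device is rejected. Below is a minimal standalone sketch of that contract; the `DeviceType` codes, `Array` struct, and `PinMemoryInPlace` helper are hypothetical stand-ins for illustration, not DGL's API.

```cpp
#include <iostream>
#include <stdexcept>

// Hypothetical device-type codes mirroring the DLPack-style enum in the diff.
enum DeviceType { kCPU = 1, kROCM = 10 };

struct Array {
  DeviceType device;
  bool pinned = false;
};

// Sketch of the documented contract: pin CPU memory in place, return if the
// array is already pinned, and reject arrays that live on the ROCm device.
void PinMemoryInPlace(Array* a) {
  if (a->pinned) return;        // IsPinned: directly return
  if (a->device == kCPU) {      // kDLCPU: will be pinned
    a->pinned = true;
    return;
  }
  // kDLROCM: invalid, will throw an error
  throw std::invalid_argument("cannot pin memory that already lives on the device");
}

int main() {
  Array cpu_arr{kCPU};
  PinMemoryInPlace(&cpu_arr);
  std::cout << "cpu array pinned: " << cpu_arr.pinned << "\n";

  Array rocm_arr{kROCM};
  try {
    PinMemoryInPlace(&rocm_arr);
  } catch (const std::invalid_argument& e) {
    std::cout << "rocm array rejected: " << e.what() << "\n";
  }
  return 0;
}
```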
@@ -46,8 +46,8 @@
   if ((val) == kDLCPU) { \
     constexpr auto XPU = kDLCPU; \
     {__VA_ARGS__} \
-  } else if ((val) == kDLGPU) { \
+  } else if ((val) == kDLROCM) { \
-    constexpr auto XPU = kDLGPU; \
+    constexpr auto XPU = kDLROCM; \
     {__VA_ARGS__} \
   } else { \
     LOG(FATAL) << "Operator " << (op) << " does not support " \
...
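This change swaps the second dispatch branch from kDLGPU to kDLROCM, so the compile-time `XPU` constant is bound to the ROCm device type before the operator body is expanded. The following is a rough standalone sketch of that dispatch pattern under assumed names (`DEVICE_SWITCH`, `kCPU`, `kROCM`); it is not the actual DGL macro.

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical device-type codes standing in for the DLPack enum.
enum DeviceType { kCPU = 1, kROCM = 10 };

// Sketch of the switch the macro implements: pick a compile-time XPU constant
// from a runtime device-type value and run the templated body with it.
#define DEVICE_SWITCH(val, XPU, op, ...)                          \
  do {                                                            \
    if ((val) == kCPU) {                                          \
      constexpr auto XPU = kCPU;                                  \
      { __VA_ARGS__ }                                             \
    } else if ((val) == kROCM) {                                  \
      constexpr auto XPU = kROCM;                                 \
      { __VA_ARGS__ }                                             \
    } else {                                                      \
      throw std::runtime_error(std::string("Operator ") + (op) +  \
                               " does not support this device");  \
    }                                                             \
  } while (0)

template <DeviceType XPU>
void Fill() {
  std::cout << "running Fill on device type " << XPU << "\n";
}

int main() {
  DeviceType dev = kROCM;  // runtime value, e.g. taken from an array's context
  DEVICE_SWITCH(dev, XPU, "Fill", Fill<XPU>(); );
  return 0;
}
```

The `constexpr auto XPU` declaration is what lets the body instantiate device-specific templates such as `Fill<XPU>()` even though the device type is only known at run time.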
@@ -173,7 +173,7 @@ class NDArray {
  * \note This is an in-place method. Behavior depends on the current context,
  * kDLCPU: will be pinned;
  * IsPinned: directly return;
- * kDLGPU: invalid, will throw an error.
+ * kDLROCM: invalid, will throw an error.
  */
 inline void PinMemory_();
 /*!
@@ -303,7 +303,7 @@ class NDArray {
  * Behavior depends on the current context,
  * kDLCPU: will be pinned;
  * IsPinned: directly return;
- * kDLGPU: invalid, will throw an error.
+ * kDLROCM: invalid, will throw an error.
  */
 DGL_DLL static void PinContainer(Container* ptr);
@@ -600,7 +600,7 @@ inline const char* TypeCode2Str(int type_code) {
 inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
   switch (device_type) {
     case kDLCPU: return "cpu";
-    case kDLGPU: return "cuda";
+    case kDLROCM: return "cuda";
     case kDLCPUPinned: return "cpu_pinned";
     case kDLOpenCL: return "opencl";
     case kDLVulkan: return "vulkan";
...
@@ -89,12 +89,18 @@ def device_id(ctx):
     else:
         return ctx.index
 
+__devtype_th_map = {
+    1: "cpu",
+    2: "cuda",  # cuda device
+    10: "cuda"  # rocm device
+}
+
 def to_backend_ctx(dglctx):
     dev_type = dglctx.device_type
-    if dev_type == 1:
-        return th.device('cpu')
-    elif dev_type == 2:
-        return th.device('cuda', dglctx.device_id)
+    if dev_type in __devtype_th_map:
+        th_type = __devtype_th_map[dev_type]
+        return th.device(th_type, dglctx.device_id)
     else:
         raise ValueError('Unsupported DGL device context:', dglctx)
...
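The backend change routes both device-type code 2 (CUDA) and code 10 (ROCm) to a torch "cuda" device through a small lookup table instead of hard-coded branches. A comparable C++ sketch of that lookup is shown below; the function name and map are hypothetical illustrations of the same mapping, not part of DGL.

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical mirror of the __devtype_th_map lookup added in the diff:
// DGL device-type codes mapped to the backend device string. Both the CUDA
// code (2) and the ROCm code (10) resolve to "cuda", matching the Python map.
std::string ToBackendDevice(int dgl_device_type) {
  static const std::unordered_map<int, std::string> devtype_map = {
      {1, "cpu"},
      {2, "cuda"},   // cuda device
      {10, "cuda"},  // rocm device
  };
  auto it = devtype_map.find(dgl_device_type);
  if (it == devtype_map.end()) {
    throw std::invalid_argument("Unsupported DGL device context");
  }
  return it->second;
}

int main() {
  std::cout << ToBackendDevice(1) << "\n";   // cpu
  std::cout << ToBackendDevice(10) << "\n";  // cuda
  return 0;
}
```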
@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
   return ret;
 }
-template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
+template IdArray CumSum<kDLROCM, int32_t>(IdArray, bool);
-template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
+template IdArray CumSum<kDLROCM, int64_t>(IdArray, bool);
 } // namespace impl
 } // namespace aten
...
@@ -51,18 +51,18 @@ NDArray IndexSelect(NDArray array, IdArray index) {
   return ret;
 }
-template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, int32_t, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, int32_t, int64_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, int64_t, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, int64_t, int64_t>(NDArray, IdArray);
 #ifdef USE_FP16
-template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, __half, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, __half, int64_t>(NDArray, IdArray);
 #endif
-template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, float, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, float, int64_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, double, int32_t>(NDArray, IdArray);
-template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
+template NDArray IndexSelect<kDLROCM, double, int64_t>(NDArray, IdArray);
 template <DLDeviceType XPU, typename DType>
 DType IndexSelect(NDArray array, int64_t index) {
@@ -84,15 +84,15 @@ DType IndexSelect(NDArray array, int64_t index) {
   return reinterpret_cast<DType&>(ret);
 }
-template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
+template int32_t IndexSelect<kDLROCM, int32_t>(NDArray array, int64_t index);
-template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
+template int64_t IndexSelect<kDLROCM, int64_t>(NDArray array, int64_t index);
-template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
+template uint32_t IndexSelect<kDLROCM, uint32_t>(NDArray array, int64_t index);
-template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
+template uint64_t IndexSelect<kDLROCM, uint64_t>(NDArray array, int64_t index);
 #ifdef USE_FP16
-template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
+template __half IndexSelect<kDLROCM, __half>(NDArray array, int64_t index);
 #endif
-template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
+template float IndexSelect<kDLROCM, float>(NDArray array, int64_t index);
-template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
+template double IndexSelect<kDLROCM, double>(NDArray array, int64_t index);
 } // namespace impl
 } // namespace aten
...
@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
   return ret.CreateView({num_nonzeros}, ret->dtype, 0);
 }
-template IdArray NonZero<kDLGPU, int32_t>(IdArray);
+template IdArray NonZero<kDLROCM, int32_t>(IdArray);
-template IdArray NonZero<kDLGPU, int64_t>(IdArray);
+template IdArray NonZero<kDLROCM, int64_t>(IdArray);
 } // namespace impl
 } // namespace aten
...
@@ -45,28 +45,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
 template <typename IdType, typename Op>
@@ -95,28 +95,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
@@ -146,28 +146,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
   return ret;
 }
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
-template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLROCM, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
 template <typename IdType, typename Op>
 __global__ void _UnaryElewiseKernel(
@@ -195,8 +195,8 @@ IdArray UnaryElewise(IdArray lhs) {
   return ret;
 }
-template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
+template IdArray UnaryElewise<kDLROCM, int32_t, arith::Neg>(IdArray lhs);
-template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
+template IdArray UnaryElewise<kDLROCM, int64_t, arith::Neg>(IdArray lhs);
 ///////////////////////////// Full /////////////////////////////
@@ -223,13 +223,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
   return ret;
 }
-template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
+template IdArray Full<kDLROCM, int32_t>(int32_t val, int64_t length, DLContext ctx);
-template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
+template IdArray Full<kDLROCM, int64_t>(int64_t val, int64_t length, DLContext ctx);
 #ifdef USE_FP16
-template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
+template IdArray Full<kDLROCM, __half>(__half val, int64_t length, DLContext ctx);
 #endif
-template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
+template IdArray Full<kDLROCM, float>(float val, int64_t length, DLContext ctx);
-template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
+template IdArray Full<kDLROCM, double>(double val, int64_t length, DLContext ctx);
 ///////////////////////////// Range /////////////////////////////
@@ -261,8 +261,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
   return ret;
 }
-template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
+template IdArray Range<kDLROCM, int32_t>(int32_t, int32_t, DLContext);
-template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
+template IdArray Range<kDLROCM, int64_t>(int64_t, int64_t, DLContext);
 ///////////////////////////// Relabel_ //////////////////////////////
@@ -339,8 +339,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
   return induced_nodes;
 }
-template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
+template IdArray Relabel_<kDLROCM, int32_t>(const std::vector<IdArray>& arrays);
-template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
+template IdArray Relabel_<kDLROCM, int64_t>(const std::vector<IdArray>& arrays);
 ///////////////////////////// AsNumBits /////////////////////////////
@@ -375,8 +375,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
 }
-template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
+template IdArray AsNumBits<kDLROCM, int32_t>(IdArray arr, uint8_t bits);
-template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
+template IdArray AsNumBits<kDLROCM, int64_t>(IdArray arr, uint8_t bits);
 } // namespace impl
 } // namespace aten
...
@@ -38,20 +38,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
     idx, val, len, outd);
 }
-template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, int32_t, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, int64_t, int32_t>(IdArray, NDArray, NDArray);
 #ifdef USE_FP16
-template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, __half, int32_t>(IdArray, NDArray, NDArray);
 #endif
-template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, float, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, double, int32_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, int32_t, int64_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, int64_t, int64_t>(IdArray, NDArray, NDArray);
 #ifdef USE_FP16
-template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, __half, int64_t>(IdArray, NDArray, NDArray);
 #endif
-template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, float, int64_t>(IdArray, NDArray, NDArray);
-template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
+template void Scatter_<kDLROCM, double, int64_t>(IdArray, NDArray, NDArray);
 }; // namespace impl
 }; // namespace aten
...
@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
   return std::make_pair(sorted_array, sorted_idx);
 }
-template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
+template std::pair<IdArray, IdArray> Sort<kDLROCM, int32_t>(IdArray, int num_bits);
-template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
+template std::pair<IdArray, IdArray> Sort<kDLROCM, int64_t>(IdArray, int num_bits);
 } // namespace impl
 } // namespace aten
...
@@ -22,7 +22,7 @@ CSRMatrix COOToCSR(COOMatrix coo) {
 }
 template <>
-CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
+CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   hipStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
@@ -101,7 +101,7 @@ __global__ void _SortedSearchKernelUpperBound(
 }
 template <>
-CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
+CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo) {
   const auto& ctx = coo.row->ctx;
   const auto nbits = coo.row->dtype.bits;
   hipStream_t stream = runtime::getCurrentCUDAStream();
@@ -134,8 +134,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
     indptr, coo.col, coo.data, col_sorted);
 }
-template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
+template CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo);
-template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
+template CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo);
 } // namespace impl
 } // namespace aten
...
@@ -132,8 +132,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
   }
 }
-template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
+template void COOSort_<kDLROCM, int32_t>(COOMatrix* coo, bool sort_column);
-template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
+template void COOSort_<kDLROCM, int64_t>(COOMatrix* coo, bool sort_column);
 ///////////////////////////// COOIsSorted /////////////////////////////
@@ -181,8 +181,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
   return {row_sorted, col_sorted};
 }
-template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
+template std::pair<bool, bool> COOIsSorted<kDLROCM, int32_t>(COOMatrix coo);
-template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
+template std::pair<bool, bool> COOIsSorted<kDLROCM, int64_t>(COOMatrix coo);
 } // namespace impl
 } // namespace aten
...
@@ -22,7 +22,7 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
 }
 template <>
-COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
+COOMatrix CSRToCOO<kDLROCM, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   hipStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
@@ -78,7 +78,7 @@ __global__ void _RepeatKernel(
 }
 template <>
-COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
+COOMatrix CSRToCOO<kDLROCM, int64_t>(CSRMatrix csr) {
   const auto& ctx = csr.indptr->ctx;
   hipStream_t stream = runtime::getCurrentCUDAStream();
@@ -100,8 +100,8 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
     true, csr.sorted);
 }
-template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
+template COOMatrix CSRToCOO<kDLROCM, int32_t>(CSRMatrix csr);
-template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
+template COOMatrix CSRToCOO<kDLROCM, int64_t>(CSRMatrix csr);
 template <DLDeviceType XPU, typename IdType>
 COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
@@ -110,8 +110,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
 }
 template <>
-COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
+COOMatrix CSRToCOODataAsOrder<kDLROCM, int32_t>(CSRMatrix csr) {
-  COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
+  COOMatrix coo = CSRToCOO<kDLROCM, int32_t>(csr);
   if (aten::IsNullArray(coo.data))
     return coo;
@@ -157,8 +157,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
 }
 template <>
-COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
+COOMatrix CSRToCOODataAsOrder<kDLROCM, int64_t>(CSRMatrix csr) {
-  COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
+  COOMatrix coo = CSRToCOO<kDLROCM, int64_t>(csr);
   if (aten::IsNullArray(coo.data))
     return coo;
   const auto& sorted = Sort(coo.data);
@@ -174,8 +174,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
   return coo;
 }
-template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
+template COOMatrix CSRToCOODataAsOrder<kDLROCM, int32_t>(CSRMatrix csr);
-template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
+template COOMatrix CSRToCOODataAsOrder<kDLROCM, int64_t>(CSRMatrix csr);
 } // namespace impl
 } // namespace aten
...
@@ -53,24 +53,24 @@ NDArray CSRGetData(
 }
 #ifdef USE_FP16
-template NDArray CSRGetData<kDLGPU, int32_t, __half>(
+template NDArray CSRGetData<kDLROCM, int32_t, __half>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
-template NDArray CSRGetData<kDLGPU, int64_t, __half>(
+template NDArray CSRGetData<kDLROCM, int64_t, __half>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
 #endif
-template NDArray CSRGetData<kDLGPU, int32_t, float>(
+template NDArray CSRGetData<kDLROCM, int32_t, float>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
-template NDArray CSRGetData<kDLGPU, int64_t, float>(
+template NDArray CSRGetData<kDLROCM, int64_t, float>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
-template NDArray CSRGetData<kDLGPU, int32_t, double>(
+template NDArray CSRGetData<kDLROCM, int32_t, double>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
-template NDArray CSRGetData<kDLGPU, int64_t, double>(
+template NDArray CSRGetData<kDLROCM, int64_t, double>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
 // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
-template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
+template NDArray CSRGetData<kDLROCM, int32_t, int32_t>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
-template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
+template NDArray CSRGetData<kDLROCM, int64_t, int64_t>(
     CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
 } // namespace impl
...
@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
 }
 #ifdef USE_FP16
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, __half>(
     const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
 #endif
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
-template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
 } // namespace aten
...
@@ -54,8 +54,8 @@ bool CSRIsSorted(CSRMatrix csr) {
   return ret;
 }
-template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
+template bool CSRIsSorted<kDLROCM, int32_t>(CSRMatrix csr);
-template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
+template bool CSRIsSorted<kDLROCM, int64_t>(CSRMatrix csr);
 template <DLDeviceType XPU, typename IdType>
 void CSRSort_(CSRMatrix* csr) {
@@ -63,7 +63,7 @@ void CSRSort_(CSRMatrix* csr) {
 }
 template <>
-void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
+void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
   hipStream_t stream = runtime::getCurrentCUDAStream();
@@ -109,7 +109,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
 }
 template <>
-void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
+void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr) {
   hipStream_t stream = runtime::getCurrentCUDAStream();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
@@ -148,8 +148,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
   device->FreeWorkspace(ctx, workspace);
 }
-template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
+template void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr);
-template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
+template void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr);
 } // namespace impl
 } // namespace aten
...
@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
 }
 #ifdef USE_FP16
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
 #endif
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int32_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
-template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
+template std::pair<CSRMatrix, NDArray> CSRSum<kDLROCM, int64_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
 } // namespace aten
...
@@ -20,7 +20,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
 }
 template <>
-CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
+CSRMatrix CSRTranspose<kDLROCM, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   hipStream_t stream = runtime::getCurrentCUDAStream();
   // allocate cusparse handle if needed
@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
 }
 template <>
-CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
+CSRMatrix CSRTranspose<kDLROCM, int64_t>(CSRMatrix csr) {
   return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
 }
-template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
+template CSRMatrix CSRTranspose<kDLROCM, int32_t>(CSRMatrix csr);
-template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
+template CSRMatrix CSRTranspose<kDLROCM, int64_t>(CSRMatrix csr);
 } // namespace impl
 } // namespace aten
...
@@ -156,8 +156,8 @@ FilterRef CreateSetFilter(IdArray set) {
   return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
 }
-template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
+template FilterRef CreateSetFilter<kDLROCM, int32_t>(IdArray set);
-template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
+template FilterRef CreateSetFilter<kDLROCM, int64_t>(IdArray set);
 } // namespace array
 } // namespace dgl