Commit 351f7588 authored by lisj's avatar lisj
Browse files

处理kDLROCM有关的函数注册

parent aaaecbc9
......@@ -600,7 +600,7 @@ inline const char* TypeCode2Str(int type_code) {
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
switch (device_type) {
case kDLCPU: return "cpu";
case kDLROCM: return "cuda";
case kDLGPU: return "cuda";
case kDLCPUPinned: return "cpu_pinned";
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
......
......@@ -27,7 +27,7 @@ namespace runtime {
inline std::string DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLROCM: return "gpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
......
......@@ -335,5 +335,11 @@ DGL_REGISTER_GLOBAL("device_api.gpu")
*rv = static_cast<void*>(ptr);
});
DGL_REGISTER_GLOBAL("device_api.rocm")
.set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment