Commit 50c27a8e authored by sangwzh's avatar sangwzh
Browse files

update device type code

parent 910d6a98
......@@ -177,11 +177,11 @@
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
(XPU == kDGLCUDA || && XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
(XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{ __VA_ARGS__ } \
} else if ( \
(XPU == kDGLCUDA || && XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
(XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
......
......@@ -55,9 +55,9 @@ typedef enum {
/** @brief CPU device */
kDGLCPU = 1,
/** @brief CUDA GPU device */
kDGLCUDA = 2,
kDGLCUDA = 10,
kDGLROCM = 2,
// add more devices once supported
kDGLROCM = 10,
} DGLDeviceType;
/**
......
......@@ -131,9 +131,9 @@ class DGLContext(ctypes.Structure):
"llvm": 1,
"stackvm": 1,
"cpu": 1,
"gpu": 2,
"cuda": 2,
"nvptx": 2,
"gpu": 10,
"cuda": 10,
"nvptx": 10,
"cl": 4,
"opencl": 4,
"aocl": 5,
......@@ -142,7 +142,7 @@ class DGLContext(ctypes.Structure):
"vulkan": 7,
"metal": 8,
"vpi": 9,
"rocm": 2,
"rocm": 10,
"opengl": 11,
"ext_dev": 12,
}
......
......@@ -80,7 +80,8 @@ def gpu(dev_id=0):
ctx : DGLContext
The created context
"""
return DGLContext(2, dev_id)
# device type for dcu is 10, nv is 2
return DGLContext(10, dev_id)
def array(arr, ctx=cpu(0)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment