Commit 50c27a8e authored by sangwzh's avatar sangwzh
Browse files

update device type code

parent 910d6a98
...@@ -177,11 +177,11 @@ ...@@ -177,11 +177,11 @@
typedef double FloatType; \ typedef double FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
(XPU == kDGLCUDA || && XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \ typedef __half FloatType; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ( \ } else if ( \
(XPU == kDGLCUDA || && XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \ (XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \ LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if ( \ } else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \ XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
......
...@@ -55,9 +55,9 @@ typedef enum { ...@@ -55,9 +55,9 @@ typedef enum {
/** @brief CPU device */ /** @brief CPU device */
kDGLCPU = 1, kDGLCPU = 1,
/** @brief CUDA GPU device */ /** @brief CUDA GPU device */
kDGLCUDA = 2, kDGLCUDA = 10,
kDGLROCM = 2,
// add more devices once supported // add more devices once supported
kDGLROCM = 10,
} DGLDeviceType; } DGLDeviceType;
/** /**
......
...@@ -131,9 +131,9 @@ class DGLContext(ctypes.Structure): ...@@ -131,9 +131,9 @@ class DGLContext(ctypes.Structure):
"llvm": 1, "llvm": 1,
"stackvm": 1, "stackvm": 1,
"cpu": 1, "cpu": 1,
"gpu": 2, "gpu": 10,
"cuda": 2, "cuda": 10,
"nvptx": 2, "nvptx": 10,
"cl": 4, "cl": 4,
"opencl": 4, "opencl": 4,
"aocl": 5, "aocl": 5,
...@@ -142,7 +142,7 @@ class DGLContext(ctypes.Structure): ...@@ -142,7 +142,7 @@ class DGLContext(ctypes.Structure):
"vulkan": 7, "vulkan": 7,
"metal": 8, "metal": 8,
"vpi": 9, "vpi": 9,
"rocm": 2, "rocm": 10,
"opengl": 11, "opengl": 11,
"ext_dev": 12, "ext_dev": 12,
} }
......
...@@ -80,7 +80,8 @@ def gpu(dev_id=0): ...@@ -80,7 +80,8 @@ def gpu(dev_id=0):
ctx : DGLContext ctx : DGLContext
The created context The created context
""" """
return DGLContext(2, dev_id) # device type for dcu is 10, nv is 2
return DGLContext(10, dev_id)
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment