Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
50c27a8e
Commit
50c27a8e
authored
Sep 30, 2024
by
sangwzh
Browse files
update device type code
parent
910d6a98
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
9 deletions
+10
-9
include/dgl/aten/macro.h
include/dgl/aten/macro.h
+2
-2
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+2
-2
python/dgl/_ffi/runtime_ctypes.py
python/dgl/_ffi/runtime_ctypes.py
+4
-4
python/dgl/ndarray.py
python/dgl/ndarray.py
+2
-1
No files found.
include/dgl/aten/macro.h
View file @
50c27a8e
...
@@ -177,11 +177,11 @@
...
@@ -177,11 +177,11 @@
typedef double FloatType; \
typedef double FloatType; \
{ __VA_ARGS__ } \
{ __VA_ARGS__ } \
} else if ( \
} else if ( \
(XPU == kDGLCUDA ||
&&
XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
(XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
typedef __half FloatType; \
{ __VA_ARGS__ } \
{ __VA_ARGS__ } \
} else if ( \
} else if ( \
(XPU == kDGLCUDA ||
&&
XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
(XPU == kDGLCUDA || XPU == kDGLROCM) && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if ( \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
...
...
include/dgl/runtime/c_runtime_api.h
View file @
50c27a8e
...
@@ -55,9 +55,9 @@ typedef enum {
...
@@ -55,9 +55,9 @@ typedef enum {
/** @brief CPU device */
/** @brief CPU device */
kDGLCPU
=
1
,
kDGLCPU
=
1
,
/** @brief CUDA GPU device */
/** @brief CUDA GPU device */
kDGLCUDA
=
2
,
kDGLCUDA
=
10
,
kDGLROCM
=
2
,
// add more devices once supported
// add more devices once supported
kDGLROCM
=
10
,
}
DGLDeviceType
;
}
DGLDeviceType
;
/**
/**
...
...
python/dgl/_ffi/runtime_ctypes.py
View file @
50c27a8e
...
@@ -131,9 +131,9 @@ class DGLContext(ctypes.Structure):
...
@@ -131,9 +131,9 @@ class DGLContext(ctypes.Structure):
"llvm"
:
1
,
"llvm"
:
1
,
"stackvm"
:
1
,
"stackvm"
:
1
,
"cpu"
:
1
,
"cpu"
:
1
,
"gpu"
:
2
,
"gpu"
:
10
,
"cuda"
:
2
,
"cuda"
:
10
,
"nvptx"
:
2
,
"nvptx"
:
10
,
"cl"
:
4
,
"cl"
:
4
,
"opencl"
:
4
,
"opencl"
:
4
,
"aocl"
:
5
,
"aocl"
:
5
,
...
@@ -142,7 +142,7 @@ class DGLContext(ctypes.Structure):
...
@@ -142,7 +142,7 @@ class DGLContext(ctypes.Structure):
"vulkan"
:
7
,
"vulkan"
:
7
,
"metal"
:
8
,
"metal"
:
8
,
"vpi"
:
9
,
"vpi"
:
9
,
"rocm"
:
2
,
"rocm"
:
10
,
"opengl"
:
11
,
"opengl"
:
11
,
"ext_dev"
:
12
,
"ext_dev"
:
12
,
}
}
...
...
python/dgl/ndarray.py
View file @
50c27a8e
...
@@ -80,7 +80,8 @@ def gpu(dev_id=0):
...
@@ -80,7 +80,8 @@ def gpu(dev_id=0):
ctx : DGLContext
ctx : DGLContext
The created context
The created context
"""
"""
return
DGLContext
(
2
,
dev_id
)
# device type for dcu is 10, nv is 2
return
DGLContext
(
10
,
dev_id
)
def
array
(
arr
,
ctx
=
cpu
(
0
)):
def
array
(
arr
,
ctx
=
cpu
(
0
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment