Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
aaaecbc9
Commit
aaaecbc9
authored
May 12, 2023
by
lisj
Browse files
处理kDLGPU为kDLROCM
parent
c454d419
Changes
54
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
179 additions
and
173 deletions
+179
-173
include/dgl/aten/coo.h
include/dgl/aten/coo.h
+1
-1
include/dgl/aten/csr.h
include/dgl/aten/csr.h
+1
-1
include/dgl/aten/macro.h
include/dgl/aten/macro.h
+2
-2
include/dgl/runtime/ndarray.h
include/dgl/runtime/ndarray.h
+3
-3
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+10
-4
src/array/cuda/array_cumsum.cu
src/array/cuda/array_cumsum.cu
+2
-2
src/array/cuda/array_index_select.cu
src/array/cuda/array_index_select.cu
+17
-17
src/array/cuda/array_nonzero.cu
src/array/cuda/array_nonzero.cu
+2
-2
src/array/cuda/array_op_impl.cu
src/array/cuda/array_op_impl.cu
+79
-79
src/array/cuda/array_scatter.cu
src/array/cuda/array_scatter.cu
+10
-10
src/array/cuda/array_sort.cu
src/array/cuda/array_sort.cu
+2
-2
src/array/cuda/coo2csr.cu
src/array/cuda/coo2csr.cu
+4
-4
src/array/cuda/coo_sort.cu
src/array/cuda/coo_sort.cu
+4
-4
src/array/cuda/csr2coo.cu
src/array/cuda/csr2coo.cu
+10
-10
src/array/cuda/csr_get_data.cu
src/array/cuda/csr_get_data.cu
+8
-8
src/array/cuda/csr_mm.cu
src/array/cuda/csr_mm.cu
+6
-6
src/array/cuda/csr_sort.cu
src/array/cuda/csr_sort.cu
+6
-6
src/array/cuda/csr_sum.cu
src/array/cuda/csr_sum.cu
+6
-6
src/array/cuda/csr_transpose.cc
src/array/cuda/csr_transpose.cc
+4
-4
src/array/cuda/cuda_filter.cu
src/array/cuda/cuda_filter.cu
+2
-2
No files found.
include/dgl/aten/coo.h
View file @
aaaecbc9
...
...
@@ -136,7 +136,7 @@ struct COOMatrix {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDL
GPU
: invalid, will throw an error.
* kDL
ROCM
: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline
void
PinMemory_
()
{
...
...
include/dgl/aten/csr.h
View file @
aaaecbc9
...
...
@@ -129,7 +129,7 @@ struct CSRMatrix {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDL
GPU
: invalid, will throw an error.
* kDL
ROCM
: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline
void
PinMemory_
()
{
...
...
include/dgl/aten/macro.h
View file @
aaaecbc9
...
...
@@ -46,8 +46,8 @@
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDL
GPU
) { \
constexpr auto XPU = kDL
GPU
; \
} else if ((val) == kDL
ROCM
) { \
constexpr auto XPU = kDL
ROCM
; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
...
...
include/dgl/runtime/ndarray.h
View file @
aaaecbc9
...
...
@@ -173,7 +173,7 @@ class NDArray {
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDL
GPU
: invalid, will throw an error.
* kDL
ROCM
: invalid, will throw an error.
*/
inline
void
PinMemory_
();
/*!
...
...
@@ -303,7 +303,7 @@ class NDArray {
* Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDL
GPU
: invalid, will throw an error.
* kDL
ROCM
: invalid, will throw an error.
*/
DGL_DLL
static
void
PinContainer
(
Container
*
ptr
);
...
...
@@ -600,7 +600,7 @@ inline const char* TypeCode2Str(int type_code) {
inline
const
char
*
DeviceTypeCode2Str
(
DLDeviceType
device_type
)
{
switch
(
device_type
)
{
case
kDLCPU
:
return
"cpu"
;
case
kDL
GPU
:
return
"cuda"
;
case
kDL
ROCM
:
return
"cuda"
;
case
kDLCPUPinned
:
return
"cpu_pinned"
;
case
kDLOpenCL
:
return
"opencl"
;
case
kDLVulkan
:
return
"vulkan"
;
...
...
python/dgl/backend/pytorch/tensor.py
View file @
aaaecbc9
...
...
@@ -89,12 +89,18 @@ def device_id(ctx):
else
:
return
ctx
.
index
__devtype_th_map
=
{
1
:
"cpu"
,
2
:
"cuda"
,
# cuda device
10
:
"cuda"
# rocm device
}
def
to_backend_ctx
(
dglctx
):
dev_type
=
dglctx
.
device_type
if
dev_type
==
1
:
return
th
.
device
(
'cpu'
)
elif
dev_type
==
2
:
return
th
.
device
(
'cuda'
,
dglctx
.
device_id
)
if
dev_type
in
__devtype_th_map
:
th_type
=
__devtype_th_map
[
dev_type
]
return
th
.
device
(
th_type
,
dglctx
.
device_id
)
else
:
raise
ValueError
(
'Unsupported DGL device context:'
,
dglctx
)
...
...
src/array/cuda/array_cumsum.cu
View file @
aaaecbc9
...
...
@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return
ret
;
}
template
IdArray
CumSum
<
kDL
GPU
,
int32_t
>(
IdArray
,
bool
);
template
IdArray
CumSum
<
kDL
GPU
,
int64_t
>(
IdArray
,
bool
);
template
IdArray
CumSum
<
kDL
ROCM
,
int32_t
>(
IdArray
,
bool
);
template
IdArray
CumSum
<
kDL
ROCM
,
int64_t
>(
IdArray
,
bool
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/array_index_select.cu
View file @
aaaecbc9
...
...
@@ -51,18 +51,18 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return
ret
;
}
template
NDArray
IndexSelect
<
kDL
GPU
,
int32_t
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
int32_t
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
int64_t
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
int64_t
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
int32_t
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
int32_t
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
int64_t
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
int64_t
,
int64_t
>(
NDArray
,
IdArray
);
#ifdef USE_FP16
template
NDArray
IndexSelect
<
kDL
GPU
,
__half
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
__half
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
__half
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
__half
,
int64_t
>(
NDArray
,
IdArray
);
#endif
template
NDArray
IndexSelect
<
kDL
GPU
,
float
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
float
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
double
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
GPU
,
double
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
float
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
float
,
int64_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
double
,
int32_t
>(
NDArray
,
IdArray
);
template
NDArray
IndexSelect
<
kDL
ROCM
,
double
,
int64_t
>(
NDArray
,
IdArray
);
template
<
DLDeviceType
XPU
,
typename
DType
>
DType
IndexSelect
(
NDArray
array
,
int64_t
index
)
{
...
...
@@ -84,15 +84,15 @@ DType IndexSelect(NDArray array, int64_t index) {
return
reinterpret_cast
<
DType
&>
(
ret
);
}
template
int32_t
IndexSelect
<
kDL
GPU
,
int32_t
>(
NDArray
array
,
int64_t
index
);
template
int64_t
IndexSelect
<
kDL
GPU
,
int64_t
>(
NDArray
array
,
int64_t
index
);
template
uint32_t
IndexSelect
<
kDL
GPU
,
uint32_t
>(
NDArray
array
,
int64_t
index
);
template
uint64_t
IndexSelect
<
kDL
GPU
,
uint64_t
>(
NDArray
array
,
int64_t
index
);
template
int32_t
IndexSelect
<
kDL
ROCM
,
int32_t
>(
NDArray
array
,
int64_t
index
);
template
int64_t
IndexSelect
<
kDL
ROCM
,
int64_t
>(
NDArray
array
,
int64_t
index
);
template
uint32_t
IndexSelect
<
kDL
ROCM
,
uint32_t
>(
NDArray
array
,
int64_t
index
);
template
uint64_t
IndexSelect
<
kDL
ROCM
,
uint64_t
>(
NDArray
array
,
int64_t
index
);
#ifdef USE_FP16
template
__half
IndexSelect
<
kDL
GPU
,
__half
>(
NDArray
array
,
int64_t
index
);
template
__half
IndexSelect
<
kDL
ROCM
,
__half
>(
NDArray
array
,
int64_t
index
);
#endif
template
float
IndexSelect
<
kDL
GPU
,
float
>(
NDArray
array
,
int64_t
index
);
template
double
IndexSelect
<
kDL
GPU
,
double
>(
NDArray
array
,
int64_t
index
);
template
float
IndexSelect
<
kDL
ROCM
,
float
>(
NDArray
array
,
int64_t
index
);
template
double
IndexSelect
<
kDL
ROCM
,
double
>(
NDArray
array
,
int64_t
index
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/array_nonzero.cu
View file @
aaaecbc9
...
...
@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
return
ret
.
CreateView
({
num_nonzeros
},
ret
->
dtype
,
0
);
}
template
IdArray
NonZero
<
kDL
GPU
,
int32_t
>(
IdArray
);
template
IdArray
NonZero
<
kDL
GPU
,
int64_t
>(
IdArray
);
template
IdArray
NonZero
<
kDL
ROCM
,
int32_t
>(
IdArray
);
template
IdArray
NonZero
<
kDL
ROCM
,
int64_t
>(
IdArray
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/array_op_impl.cu
View file @
aaaecbc9
...
...
@@ -45,28 +45,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return
ret
;
}
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Add
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Sub
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mul
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Div
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mod
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
EQ
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
NE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Add
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Sub
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mul
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Div
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mod
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
EQ
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
NE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Add
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Sub
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mul
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Div
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mod
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
EQ
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
NE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Add
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Sub
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mul
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Div
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mod
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LT
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LE
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
EQ
>(
IdArray
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
NE
>(
IdArray
lhs
,
IdArray
rhs
);
template
<
typename
IdType
,
typename
Op
>
...
...
@@ -95,28 +95,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return
ret
;
}
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Add
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Sub
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mul
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Div
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mod
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GT
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LT
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
EQ
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
NE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Add
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Sub
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mul
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Div
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mod
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GT
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LT
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GE
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LE
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
EQ
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
NE
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Add
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Sub
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mul
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Div
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mod
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GT
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LT
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
EQ
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
NE
>(
IdArray
lhs
,
int32_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Add
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Sub
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mul
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Div
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mod
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GT
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LT
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GE
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LE
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
EQ
>(
IdArray
lhs
,
int64_t
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
NE
>(
IdArray
lhs
,
int64_t
rhs
);
...
...
@@ -146,28 +146,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return
ret
;
}
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Add
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Sub
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mul
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Div
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Mod
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GT
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LT
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
GE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
LE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
EQ
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
NE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Add
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Sub
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mul
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Div
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Mod
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GT
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LT
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
GE
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
LE
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
EQ
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
NE
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Add
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Sub
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mul
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Div
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Mod
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GT
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LT
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
GE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
LE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
EQ
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
NE
>(
int32_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Add
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Sub
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mul
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Div
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Mod
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GT
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LT
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
GE
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
LE
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
EQ
>(
int64_t
lhs
,
IdArray
rhs
);
template
IdArray
BinaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
NE
>(
int64_t
lhs
,
IdArray
rhs
);
template
<
typename
IdType
,
typename
Op
>
__global__
void
_UnaryElewiseKernel
(
...
...
@@ -195,8 +195,8 @@ IdArray UnaryElewise(IdArray lhs) {
return
ret
;
}
template
IdArray
UnaryElewise
<
kDL
GPU
,
int32_t
,
arith
::
Neg
>(
IdArray
lhs
);
template
IdArray
UnaryElewise
<
kDL
GPU
,
int64_t
,
arith
::
Neg
>(
IdArray
lhs
);
template
IdArray
UnaryElewise
<
kDL
ROCM
,
int32_t
,
arith
::
Neg
>(
IdArray
lhs
);
template
IdArray
UnaryElewise
<
kDL
ROCM
,
int64_t
,
arith
::
Neg
>(
IdArray
lhs
);
///////////////////////////// Full /////////////////////////////
...
...
@@ -223,13 +223,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return
ret
;
}
template
IdArray
Full
<
kDL
GPU
,
int32_t
>(
int32_t
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
GPU
,
int64_t
>(
int64_t
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
ROCM
,
int32_t
>(
int32_t
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
ROCM
,
int64_t
>(
int64_t
val
,
int64_t
length
,
DLContext
ctx
);
#ifdef USE_FP16
template
IdArray
Full
<
kDL
GPU
,
__half
>(
__half
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
ROCM
,
__half
>(
__half
val
,
int64_t
length
,
DLContext
ctx
);
#endif
template
IdArray
Full
<
kDL
GPU
,
float
>(
float
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
GPU
,
double
>(
double
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
ROCM
,
float
>(
float
val
,
int64_t
length
,
DLContext
ctx
);
template
IdArray
Full
<
kDL
ROCM
,
double
>(
double
val
,
int64_t
length
,
DLContext
ctx
);
///////////////////////////// Range /////////////////////////////
...
...
@@ -261,8 +261,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return
ret
;
}
template
IdArray
Range
<
kDL
GPU
,
int32_t
>(
int32_t
,
int32_t
,
DLContext
);
template
IdArray
Range
<
kDL
GPU
,
int64_t
>(
int64_t
,
int64_t
,
DLContext
);
template
IdArray
Range
<
kDL
ROCM
,
int32_t
>(
int32_t
,
int32_t
,
DLContext
);
template
IdArray
Range
<
kDL
ROCM
,
int64_t
>(
int64_t
,
int64_t
,
DLContext
);
///////////////////////////// Relabel_ //////////////////////////////
...
...
@@ -339,8 +339,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return
induced_nodes
;
}
template
IdArray
Relabel_
<
kDL
GPU
,
int32_t
>(
const
std
::
vector
<
IdArray
>&
arrays
);
template
IdArray
Relabel_
<
kDL
GPU
,
int64_t
>(
const
std
::
vector
<
IdArray
>&
arrays
);
template
IdArray
Relabel_
<
kDL
ROCM
,
int32_t
>(
const
std
::
vector
<
IdArray
>&
arrays
);
template
IdArray
Relabel_
<
kDL
ROCM
,
int64_t
>(
const
std
::
vector
<
IdArray
>&
arrays
);
///////////////////////////// AsNumBits /////////////////////////////
...
...
@@ -375,8 +375,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
}
template
IdArray
AsNumBits
<
kDL
GPU
,
int32_t
>(
IdArray
arr
,
uint8_t
bits
);
template
IdArray
AsNumBits
<
kDL
GPU
,
int64_t
>(
IdArray
arr
,
uint8_t
bits
);
template
IdArray
AsNumBits
<
kDL
ROCM
,
int32_t
>(
IdArray
arr
,
uint8_t
bits
);
template
IdArray
AsNumBits
<
kDL
ROCM
,
int64_t
>(
IdArray
arr
,
uint8_t
bits
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/array_scatter.cu
View file @
aaaecbc9
...
...
@@ -38,20 +38,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
idx
,
val
,
len
,
outd
);
}
template
void
Scatter_
<
kDL
GPU
,
int32_t
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
GPU
,
int64_t
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
int32_t
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
int64_t
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
#ifdef USE_FP16
template
void
Scatter_
<
kDL
GPU
,
__half
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
__half
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
#endif
template
void
Scatter_
<
kDL
GPU
,
float
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
GPU
,
double
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
GPU
,
int32_t
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
GPU
,
int64_t
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
float
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
double
,
int32_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
int32_t
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
int64_t
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
#ifdef USE_FP16
template
void
Scatter_
<
kDL
GPU
,
__half
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
__half
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
#endif
template
void
Scatter_
<
kDL
GPU
,
float
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
GPU
,
double
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
float
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
template
void
Scatter_
<
kDL
ROCM
,
double
,
int64_t
>(
IdArray
,
NDArray
,
NDArray
);
};
// namespace impl
};
// namespace aten
...
...
src/array/cuda/array_sort.cu
View file @
aaaecbc9
...
...
@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
return
std
::
make_pair
(
sorted_array
,
sorted_idx
);
}
template
std
::
pair
<
IdArray
,
IdArray
>
Sort
<
kDL
GPU
,
int32_t
>
(
IdArray
,
int
num_bits
);
template
std
::
pair
<
IdArray
,
IdArray
>
Sort
<
kDL
GPU
,
int64_t
>
(
IdArray
,
int
num_bits
);
template
std
::
pair
<
IdArray
,
IdArray
>
Sort
<
kDL
ROCM
,
int32_t
>
(
IdArray
,
int
num_bits
);
template
std
::
pair
<
IdArray
,
IdArray
>
Sort
<
kDL
ROCM
,
int64_t
>
(
IdArray
,
int
num_bits
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/coo2csr.cu
View file @
aaaecbc9
...
...
@@ -22,7 +22,7 @@ CSRMatrix COOToCSR(COOMatrix coo) {
}
template
<
>
CSRMatrix
COOToCSR
<
kDL
GPU
,
int32_t
>
(
COOMatrix
coo
)
{
CSRMatrix
COOToCSR
<
kDL
ROCM
,
int32_t
>
(
COOMatrix
coo
)
{
auto
*
thr_entry
=
runtime
::
CUDAThreadEntry
::
ThreadLocal
();
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
// allocate cusparse handle if needed
...
...
@@ -101,7 +101,7 @@ __global__ void _SortedSearchKernelUpperBound(
}
template
<
>
CSRMatrix
COOToCSR
<
kDL
GPU
,
int64_t
>
(
COOMatrix
coo
)
{
CSRMatrix
COOToCSR
<
kDL
ROCM
,
int64_t
>
(
COOMatrix
coo
)
{
const
auto
&
ctx
=
coo
.
row
->
ctx
;
const
auto
nbits
=
coo
.
row
->
dtype
.
bits
;
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
...
...
@@ -134,8 +134,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
indptr
,
coo
.
col
,
coo
.
data
,
col_sorted
);
}
template
CSRMatrix
COOToCSR
<
kDL
GPU
,
int32_t
>(
COOMatrix
coo
);
template
CSRMatrix
COOToCSR
<
kDL
GPU
,
int64_t
>(
COOMatrix
coo
);
template
CSRMatrix
COOToCSR
<
kDL
ROCM
,
int32_t
>(
COOMatrix
coo
);
template
CSRMatrix
COOToCSR
<
kDL
ROCM
,
int64_t
>(
COOMatrix
coo
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/coo_sort.cu
View file @
aaaecbc9
...
...
@@ -132,8 +132,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
}
}
template
void
COOSort_
<
kDL
GPU
,
int32_t
>(
COOMatrix
*
coo
,
bool
sort_column
);
template
void
COOSort_
<
kDL
GPU
,
int64_t
>(
COOMatrix
*
coo
,
bool
sort_column
);
template
void
COOSort_
<
kDL
ROCM
,
int32_t
>(
COOMatrix
*
coo
,
bool
sort_column
);
template
void
COOSort_
<
kDL
ROCM
,
int64_t
>(
COOMatrix
*
coo
,
bool
sort_column
);
///////////////////////////// COOIsSorted /////////////////////////////
...
...
@@ -181,8 +181,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return
{
row_sorted
,
col_sorted
};
}
template
std
::
pair
<
bool
,
bool
>
COOIsSorted
<
kDL
GPU
,
int32_t
>
(
COOMatrix
coo
);
template
std
::
pair
<
bool
,
bool
>
COOIsSorted
<
kDL
GPU
,
int64_t
>
(
COOMatrix
coo
);
template
std
::
pair
<
bool
,
bool
>
COOIsSorted
<
kDL
ROCM
,
int32_t
>
(
COOMatrix
coo
);
template
std
::
pair
<
bool
,
bool
>
COOIsSorted
<
kDL
ROCM
,
int64_t
>
(
COOMatrix
coo
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/csr2coo.cu
View file @
aaaecbc9
...
...
@@ -22,7 +22,7 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
}
template
<
>
COOMatrix
CSRToCOO
<
kDL
GPU
,
int32_t
>
(
CSRMatrix
csr
)
{
COOMatrix
CSRToCOO
<
kDL
ROCM
,
int32_t
>
(
CSRMatrix
csr
)
{
auto
*
thr_entry
=
runtime
::
CUDAThreadEntry
::
ThreadLocal
();
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
// allocate cusparse handle if needed
...
...
@@ -78,7 +78,7 @@ __global__ void _RepeatKernel(
}
template
<
>
COOMatrix
CSRToCOO
<
kDL
GPU
,
int64_t
>
(
CSRMatrix
csr
)
{
COOMatrix
CSRToCOO
<
kDL
ROCM
,
int64_t
>
(
CSRMatrix
csr
)
{
const
auto
&
ctx
=
csr
.
indptr
->
ctx
;
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
...
...
@@ -100,8 +100,8 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
true
,
csr
.
sorted
);
}
template
COOMatrix
CSRToCOO
<
kDL
GPU
,
int32_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOO
<
kDL
GPU
,
int64_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOO
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOO
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
);
template
<
DLDeviceType
XPU
,
typename
IdType
>
COOMatrix
CSRToCOODataAsOrder
(
CSRMatrix
csr
)
{
...
...
@@ -110,8 +110,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
}
template
<
>
COOMatrix
CSRToCOODataAsOrder
<
kDL
GPU
,
int32_t
>
(
CSRMatrix
csr
)
{
COOMatrix
coo
=
CSRToCOO
<
kDL
GPU
,
int32_t
>
(
csr
);
COOMatrix
CSRToCOODataAsOrder
<
kDL
ROCM
,
int32_t
>
(
CSRMatrix
csr
)
{
COOMatrix
coo
=
CSRToCOO
<
kDL
ROCM
,
int32_t
>
(
csr
);
if
(
aten
::
IsNullArray
(
coo
.
data
))
return
coo
;
...
...
@@ -157,8 +157,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
}
template
<
>
COOMatrix
CSRToCOODataAsOrder
<
kDL
GPU
,
int64_t
>
(
CSRMatrix
csr
)
{
COOMatrix
coo
=
CSRToCOO
<
kDL
GPU
,
int64_t
>
(
csr
);
COOMatrix
CSRToCOODataAsOrder
<
kDL
ROCM
,
int64_t
>
(
CSRMatrix
csr
)
{
COOMatrix
coo
=
CSRToCOO
<
kDL
ROCM
,
int64_t
>
(
csr
);
if
(
aten
::
IsNullArray
(
coo
.
data
))
return
coo
;
const
auto
&
sorted
=
Sort
(
coo
.
data
);
...
...
@@ -174,8 +174,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
return
coo
;
}
template
COOMatrix
CSRToCOODataAsOrder
<
kDL
GPU
,
int32_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOODataAsOrder
<
kDL
GPU
,
int64_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOODataAsOrder
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
);
template
COOMatrix
CSRToCOODataAsOrder
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/csr_get_data.cu
View file @
aaaecbc9
...
...
@@ -53,24 +53,24 @@ NDArray CSRGetData(
}
#ifdef USE_FP16
template
NDArray
CSRGetData
<
kDL
GPU
,
int32_t
,
__half
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int32_t
,
__half
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
__half
filler
);
template
NDArray
CSRGetData
<
kDL
GPU
,
int64_t
,
__half
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int64_t
,
__half
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
__half
filler
);
#endif
template
NDArray
CSRGetData
<
kDL
GPU
,
int32_t
,
float
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int32_t
,
float
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
float
filler
);
template
NDArray
CSRGetData
<
kDL
GPU
,
int64_t
,
float
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int64_t
,
float
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
float
filler
);
template
NDArray
CSRGetData
<
kDL
GPU
,
int32_t
,
double
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int32_t
,
double
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
double
filler
);
template
NDArray
CSRGetData
<
kDL
GPU
,
int64_t
,
double
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int64_t
,
double
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
double
filler
);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template
NDArray
CSRGetData
<
kDL
GPU
,
int32_t
,
int32_t
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int32_t
,
int32_t
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
int32_t
filler
);
template
NDArray
CSRGetData
<
kDL
GPU
,
int64_t
,
int64_t
>(
template
NDArray
CSRGetData
<
kDL
ROCM
,
int64_t
,
int64_t
>(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
,
bool
return_eids
,
NDArray
weights
,
int64_t
filler
);
}
// namespace impl
...
...
src/array/cuda/csr_mm.cu
View file @
aaaecbc9
...
...
@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
#ifdef USE_FP16
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int32_t
,
__half
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int32_t
,
__half
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int64_t
,
__half
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int64_t
,
__half
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
#endif
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int32_t
,
float
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int32_t
,
float
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int64_t
,
float
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int64_t
,
float
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int32_t
,
double
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int32_t
,
double
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
GPU
,
int64_t
,
double
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRMM
<
kDL
ROCM
,
int64_t
,
double
>
(
const
CSRMatrix
&
,
NDArray
,
const
CSRMatrix
&
,
NDArray
);
}
// namespace aten
...
...
src/array/cuda/csr_sort.cu
View file @
aaaecbc9
...
...
@@ -54,8 +54,8 @@ bool CSRIsSorted(CSRMatrix csr) {
return
ret
;
}
template
bool
CSRIsSorted
<
kDL
GPU
,
int32_t
>(
CSRMatrix
csr
);
template
bool
CSRIsSorted
<
kDL
GPU
,
int64_t
>(
CSRMatrix
csr
);
template
bool
CSRIsSorted
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
);
template
bool
CSRIsSorted
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
);
template
<
DLDeviceType
XPU
,
typename
IdType
>
void
CSRSort_
(
CSRMatrix
*
csr
)
{
...
...
@@ -63,7 +63,7 @@ void CSRSort_(CSRMatrix* csr) {
}
template
<
>
void
CSRSort_
<
kDL
GPU
,
int32_t
>
(
CSRMatrix
*
csr
)
{
void
CSRSort_
<
kDL
ROCM
,
int32_t
>
(
CSRMatrix
*
csr
)
{
auto
*
thr_entry
=
runtime
::
CUDAThreadEntry
::
ThreadLocal
();
auto
device
=
runtime
::
DeviceAPI
::
Get
(
csr
->
indptr
->
ctx
);
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
...
...
@@ -109,7 +109,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
}
template
<
>
void
CSRSort_
<
kDL
GPU
,
int64_t
>
(
CSRMatrix
*
csr
)
{
void
CSRSort_
<
kDL
ROCM
,
int64_t
>
(
CSRMatrix
*
csr
)
{
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
auto
device
=
runtime
::
DeviceAPI
::
Get
(
csr
->
indptr
->
ctx
);
...
...
@@ -148,8 +148,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
device
->
FreeWorkspace
(
ctx
,
workspace
);
}
template
void
CSRSort_
<
kDL
GPU
,
int32_t
>(
CSRMatrix
*
csr
);
template
void
CSRSort_
<
kDL
GPU
,
int64_t
>(
CSRMatrix
*
csr
);
template
void
CSRSort_
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
*
csr
);
template
void
CSRSort_
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
*
csr
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/csr_sum.cu
View file @
aaaecbc9
...
...
@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
#ifdef USE_FP16
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int32_t
,
__half
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int32_t
,
__half
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int64_t
,
__half
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int64_t
,
__half
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
#endif
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int32_t
,
float
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int32_t
,
float
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int64_t
,
float
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int64_t
,
float
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int32_t
,
double
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int32_t
,
double
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
GPU
,
int64_t
,
double
>
(
template
std
::
pair
<
CSRMatrix
,
NDArray
>
CSRSum
<
kDL
ROCM
,
int64_t
,
double
>
(
const
std
::
vector
<
CSRMatrix
>&
,
const
std
::
vector
<
NDArray
>&
);
}
// namespace aten
...
...
src/array/cuda/csr_transpose.cc
View file @
aaaecbc9
...
...
@@ -20,7 +20,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
}
template
<
>
CSRMatrix
CSRTranspose
<
kDL
GPU
,
int32_t
>
(
CSRMatrix
csr
)
{
CSRMatrix
CSRTranspose
<
kDL
ROCM
,
int32_t
>
(
CSRMatrix
csr
)
{
auto
*
thr_entry
=
runtime
::
CUDAThreadEntry
::
ThreadLocal
();
hipStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
// allocate cusparse handle if needed
...
...
@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
}
template
<
>
CSRMatrix
CSRTranspose
<
kDL
GPU
,
int64_t
>
(
CSRMatrix
csr
)
{
CSRMatrix
CSRTranspose
<
kDL
ROCM
,
int64_t
>
(
CSRMatrix
csr
)
{
return
COOToCSR
(
COOTranspose
(
CSRToCOO
(
csr
,
false
)));
}
template
CSRMatrix
CSRTranspose
<
kDL
GPU
,
int32_t
>(
CSRMatrix
csr
);
template
CSRMatrix
CSRTranspose
<
kDL
GPU
,
int64_t
>(
CSRMatrix
csr
);
template
CSRMatrix
CSRTranspose
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
);
template
CSRMatrix
CSRTranspose
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
);
}
// namespace impl
}
// namespace aten
...
...
src/array/cuda/cuda_filter.cu
View file @
aaaecbc9
...
...
@@ -156,8 +156,8 @@ FilterRef CreateSetFilter(IdArray set) {
return
FilterRef
(
std
::
make_shared
<
CudaFilterSet
<
IdType
>>
(
set
));
}
template
FilterRef
CreateSetFilter
<
kDL
GPU
,
int32_t
>(
IdArray
set
);
template
FilterRef
CreateSetFilter
<
kDL
GPU
,
int64_t
>(
IdArray
set
);
template
FilterRef
CreateSetFilter
<
kDL
ROCM
,
int32_t
>(
IdArray
set
);
template
FilterRef
CreateSetFilter
<
kDL
ROCM
,
int64_t
>(
IdArray
set
);
}
// namespace array
}
// namespace dgl
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment