Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
aaaecbc9
Commit
aaaecbc9
authored
May 12, 2023
by
lisj
Browse files
处理kDLGPU为kDLROCM
parent
c454d419
Changes
54
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
147 additions
and
147 deletions
+147
-147
src/array/cuda/disjoint_union.cu
src/array/cuda/disjoint_union.cu
+2
-2
src/array/cuda/gather_mm.cu
src/array/cuda/gather_mm.cu
+24
-24
src/array/cuda/negative_sampling.cu
src/array/cuda/negative_sampling.cu
+2
-2
src/array/cuda/rowwise_sampling.cu
src/array/cuda/rowwise_sampling.cu
+2
-2
src/array/cuda/rowwise_sampling_prob.cu
src/array/cuda/rowwise_sampling_prob.cu
+4
-4
src/array/cuda/sddmm.cu
src/array/cuda/sddmm.cu
+12
-12
src/array/cuda/sddmm_hetero_coo.cu
src/array/cuda/sddmm_hetero_coo.cu
+6
-6
src/array/cuda/sddmm_hetero_csr.cu
src/array/cuda/sddmm_hetero_csr.cu
+6
-6
src/array/cuda/segment_reduce.cu
src/array/cuda/segment_reduce.cu
+24
-24
src/array/cuda/spmat_op_impl_coo.cu
src/array/cuda/spmat_op_impl_coo.cu
+4
-4
src/array/cuda/spmat_op_impl_csr.cu
src/array/cuda/spmat_op_impl_csr.cu
+22
-22
src/array/cuda/spmm.cu
src/array/cuda/spmm.cu
+12
-12
src/array/cuda/spmm_hetero.cu
src/array/cuda/spmm_hetero.cu
+6
-6
src/array/cuda/uvm/array_index_select_uvm.cu
src/array/cuda/uvm/array_index_select_uvm.cu
+3
-3
src/array/filter.cc
src/array/filter.cc
+2
-2
src/array/uvm_array.cc
src/array/uvm_array.cc
+3
-3
src/geometry/cuda/edge_coarsening_impl.cu
src/geometry/cuda/edge_coarsening_impl.cu
+6
-6
src/geometry/cuda/geometry_op_impl.cu
src/geometry/cuda/geometry_op_impl.cu
+4
-4
src/graph/heterograph.h
src/graph/heterograph.h
+1
-1
src/graph/sampling/randomwalks/get_node_types_gpu.cu
src/graph/sampling/randomwalks/get_node_types_gpu.cu
+2
-2
No files found.
src/array/cuda/disjoint_union.cu
View file @
aaaecbc9
...
@@ -177,8 +177,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
...
@@ -177,8 +177,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted
);
col_sorted
);
}
}
template
COOMatrix
DisjointUnionCoo
<
kDL
GPU
,
int32_t
>(
const
std
::
vector
<
COOMatrix
>&
coos
);
template
COOMatrix
DisjointUnionCoo
<
kDL
ROCM
,
int32_t
>(
const
std
::
vector
<
COOMatrix
>&
coos
);
template
COOMatrix
DisjointUnionCoo
<
kDL
GPU
,
int64_t
>(
const
std
::
vector
<
COOMatrix
>&
coos
);
template
COOMatrix
DisjointUnionCoo
<
kDL
ROCM
,
int64_t
>(
const
std
::
vector
<
COOMatrix
>&
coos
);
}
// namespace impl
}
// namespace impl
}
// namespace aten
}
// namespace aten
...
...
src/array/cuda/gather_mm.cu
View file @
aaaecbc9
...
@@ -395,74 +395,74 @@ void GatherMMScatter(const NDArray A,
...
@@ -395,74 +395,74 @@ void GatherMMScatter(const NDArray A,
}
}
template
void
GatherMM
<
kDL
GPU
,
int32_t
,
16
>(
template
void
GatherMM
<
kDL
ROCM
,
int32_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMM
<
kDL
GPU
,
int64_t
,
16
>(
template
void
GatherMM
<
kDL
ROCM
,
int64_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMM
<
kDL
GPU
,
int32_t
,
32
>(
template
void
GatherMM
<
kDL
ROCM
,
int32_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMM
<
kDL
GPU
,
int64_t
,
32
>(
template
void
GatherMM
<
kDL
ROCM
,
int64_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMM
<
kDL
GPU
,
int32_t
,
64
>(
template
void
GatherMM
<
kDL
ROCM
,
int32_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMM
<
kDL
GPU
,
int64_t
,
64
>(
template
void
GatherMM
<
kDL
ROCM
,
int64_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
);
const
NDArray
idx_a
,
const
NDArray
idx_b
);
template
void
GatherMMScatter
<
kDL
GPU
,
int32_t
,
16
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int32_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
GatherMMScatter
<
kDL
GPU
,
int64_t
,
16
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int64_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
GatherMMScatter
<
kDL
GPU
,
int32_t
,
32
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int32_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
GatherMMScatter
<
kDL
GPU
,
int64_t
,
32
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int64_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
GatherMMScatter
<
kDL
GPU
,
int32_t
,
64
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int32_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
GatherMMScatter
<
kDL
GPU
,
int64_t
,
64
>(
template
void
GatherMMScatter
<
kDL
ROCM
,
int64_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
const
NDArray
idx_a
,
const
NDArray
idx_b
,
const
NDArray
idx_c
);
template
void
SegmentMM
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SegmentMM
<
kDL
ROCM
,
int32_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMM
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SegmentMM
<
kDL
ROCM
,
int64_t
,
16
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMM
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SegmentMM
<
kDL
ROCM
,
int32_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMM
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SegmentMM
<
kDL
ROCM
,
int64_t
,
32
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMM
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SegmentMM
<
kDL
ROCM
,
int32_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMM
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SegmentMM
<
kDL
ROCM
,
int64_t
,
64
>(
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
A
,
const
NDArray
B
,
NDArray
C
,
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
const
NDArray
seglen_A
,
bool
a_trans
,
bool
b_trans
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int32_t
,
16
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int64_t
,
16
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int32_t
,
32
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int64_t
,
32
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int32_t
,
64
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
template
void
SegmentMMBackwardB
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SegmentMMBackwardB
<
kDL
ROCM
,
int64_t
,
64
>(
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
const
NDArray
A
,
const
NDArray
dC
,
NDArray
dB
,
const
NDArray
seglen
);
}
// namespace aten
}
// namespace aten
...
...
src/array/cuda/negative_sampling.cu
View file @
aaaecbc9
...
@@ -212,9 +212,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
...
@@ -212,9 +212,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return
result
;
return
result
;
}
}
template
std
::
pair
<
IdArray
,
IdArray
>
CSRGlobalUniformNegativeSampling
<
kDL
GPU
,
int32_t
>
(
template
std
::
pair
<
IdArray
,
IdArray
>
CSRGlobalUniformNegativeSampling
<
kDL
ROCM
,
int32_t
>
(
const
CSRMatrix
&
,
int64_t
,
int
,
bool
,
bool
,
double
);
const
CSRMatrix
&
,
int64_t
,
int
,
bool
,
bool
,
double
);
template
std
::
pair
<
IdArray
,
IdArray
>
CSRGlobalUniformNegativeSampling
<
kDL
GPU
,
int64_t
>
(
template
std
::
pair
<
IdArray
,
IdArray
>
CSRGlobalUniformNegativeSampling
<
kDL
ROCM
,
int64_t
>
(
const
CSRMatrix
&
,
int64_t
,
int
,
bool
,
bool
,
double
);
const
CSRMatrix
&
,
int64_t
,
int
,
bool
,
bool
,
double
);
};
// namespace impl
};
// namespace impl
...
...
src/array/cuda/rowwise_sampling.cu
View file @
aaaecbc9
...
@@ -370,9 +370,9 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
...
@@ -370,9 +370,9 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
picked_col
,
picked_idx
);
picked_col
,
picked_idx
);
}
}
template
COOMatrix
CSRRowWiseSamplingUniform
<
kDL
GPU
,
int32_t
>(
template
COOMatrix
CSRRowWiseSamplingUniform
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
IdArray
,
int64_t
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
bool
);
template
COOMatrix
CSRRowWiseSamplingUniform
<
kDL
GPU
,
int64_t
>(
template
COOMatrix
CSRRowWiseSamplingUniform
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
IdArray
,
int64_t
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
bool
);
}
// namespace impl
}
// namespace impl
...
...
src/array/cuda/rowwise_sampling_prob.cu
View file @
aaaecbc9
...
@@ -652,13 +652,13 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
...
@@ -652,13 +652,13 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
picked_col
,
picked_idx
);
picked_col
,
picked_idx
);
}
}
template
COOMatrix
CSRRowWiseSampling
<
kDL
GPU
,
int32_t
,
float
>(
template
COOMatrix
CSRRowWiseSampling
<
kDL
ROCM
,
int32_t
,
float
>(
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
template
COOMatrix
CSRRowWiseSampling
<
kDL
GPU
,
int64_t
,
float
>(
template
COOMatrix
CSRRowWiseSampling
<
kDL
ROCM
,
int64_t
,
float
>(
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
template
COOMatrix
CSRRowWiseSampling
<
kDL
GPU
,
int32_t
,
double
>(
template
COOMatrix
CSRRowWiseSampling
<
kDL
ROCM
,
int32_t
,
double
>(
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
template
COOMatrix
CSRRowWiseSampling
<
kDL
GPU
,
int64_t
,
double
>(
template
COOMatrix
CSRRowWiseSampling
<
kDL
ROCM
,
int64_t
,
double
>(
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
CSRMatrix
,
IdArray
,
int64_t
,
FloatArray
,
bool
);
}
// namespace impl
}
// namespace impl
...
...
src/array/cuda/sddmm.cu
View file @
aaaecbc9
...
@@ -54,52 +54,52 @@ void SDDMMCoo(const std::string& op,
...
@@ -54,52 +54,52 @@ void SDDMMCoo(const std::string& op,
}
}
template
void
SDDMMCsr
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCsr
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCsr
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCsr
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCsr
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCsr
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SDDMMCsr
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
template
void
SDDMMCoo
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SDDMMCoo
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
NDArray
lhs
,
NDArray
rhs
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
int
lhs_target
,
int
rhs_target
);
...
...
src/array/cuda/sddmm_hetero_coo.cu
View file @
aaaecbc9
...
@@ -42,42 +42,42 @@ void SDDMMCooHetero(const std::string& op,
...
@@ -42,42 +42,42 @@ void SDDMMCooHetero(const std::string& op,
}
}
template
void
SDDMMCooHetero
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCooHetero
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCooHetero
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCooHetero
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCooHetero
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCooHetero
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SDDMMCooHetero
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
COOMatrix
>&
vec_coo
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
...
...
src/array/cuda/sddmm_hetero_csr.cu
View file @
aaaecbc9
...
@@ -41,42 +41,42 @@ void SDDMMCsrHetero(const std::string& op,
...
@@ -41,42 +41,42 @@ void SDDMMCsrHetero(const std::string& op,
});
});
}
}
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
std
::
vector
<
NDArray
>
out
,
int
lhs_target
,
int
rhs_target
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
in_eid
,
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
const
std
::
vector
<
dgl_type_t
>&
out_eid
);
template
void
SDDMMCsrHetero
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SDDMMCsrHetero
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
string
&
op
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
CSRMatrix
>&
vec_csr
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
const
std
::
vector
<
NDArray
>&
lhs
,
const
std
::
vector
<
NDArray
>&
rhs
,
...
...
src/array/cuda/segment_reduce.cu
View file @
aaaecbc9
...
@@ -73,113 +73,113 @@ void BackwardSegmentCmp(NDArray feat,
...
@@ -73,113 +73,113 @@ void BackwardSegmentCmp(NDArray feat,
}
}
template
void
SegmentReduce
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
SegmentReduce
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
SegmentReduce
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
SegmentReduce
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
SegmentReduce
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
SegmentReduce
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SegmentReduce
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
op
,
NDArray
feat
,
NDArray
feat
,
NDArray
offsets
,
NDArray
offsets
,
NDArray
out
,
NDArray
out
,
NDArray
arg
);
NDArray
arg
);
template
void
ScatterAdd
<
kDL
GPU
,
int32_t
,
16
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int32_t
,
16
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
ScatterAdd
<
kDL
GPU
,
int64_t
,
16
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int64_t
,
16
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
ScatterAdd
<
kDL
GPU
,
int32_t
,
32
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int32_t
,
32
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
ScatterAdd
<
kDL
GPU
,
int64_t
,
32
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int64_t
,
32
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
ScatterAdd
<
kDL
GPU
,
int32_t
,
64
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int32_t
,
64
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
ScatterAdd
<
kDL
GPU
,
int64_t
,
64
>(
template
void
ScatterAdd
<
kDL
ROCM
,
int64_t
,
64
>(
NDArray
feat
,
NDArray
feat
,
NDArray
idx
,
NDArray
idx
,
NDArray
out
);
NDArray
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int32_t
,
16
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int32_t
,
16
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int64_t
,
16
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int64_t
,
16
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int32_t
,
32
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int32_t
,
32
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int64_t
,
32
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int64_t
,
32
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int32_t
,
64
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int32_t
,
64
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
UpdateGradMinMax_hetero
<
kDL
GPU
,
int64_t
,
64
>(
template
void
UpdateGradMinMax_hetero
<
kDL
ROCM
,
int64_t
,
64
>(
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
HeteroGraphPtr
&
g
,
const
std
::
string
&
op
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
feat
,
const
std
::
vector
<
NDArray
>&
idx
,
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
const
std
::
vector
<
NDArray
>&
idx_etype
,
std
::
vector
<
NDArray
>*
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int32_t
,
16
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int32_t
,
16
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int64_t
,
16
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int64_t
,
16
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int32_t
,
32
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int32_t
,
32
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int64_t
,
32
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int64_t
,
32
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int32_t
,
64
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int32_t
,
64
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
template
void
BackwardSegmentCmp
<
kDL
GPU
,
int64_t
,
64
>(
template
void
BackwardSegmentCmp
<
kDL
ROCM
,
int64_t
,
64
>(
NDArray
feat
,
NDArray
feat
,
NDArray
arg
,
NDArray
arg
,
NDArray
out
);
NDArray
out
);
...
...
src/array/cuda/spmat_op_impl_coo.cu
View file @
aaaecbc9
...
@@ -89,8 +89,8 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
...
@@ -89,8 +89,8 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
return
*
rst
.
Ptr
<
IdType
>
();
return
*
rst
.
Ptr
<
IdType
>
();
}
}
template
int64_t
COOGetRowNNZ
<
kDL
GPU
,
int32_t
>(
COOMatrix
,
int64_t
);
template
int64_t
COOGetRowNNZ
<
kDL
ROCM
,
int32_t
>(
COOMatrix
,
int64_t
);
template
int64_t
COOGetRowNNZ
<
kDL
GPU
,
int64_t
>(
COOMatrix
,
int64_t
);
template
int64_t
COOGetRowNNZ
<
kDL
ROCM
,
int64_t
>(
COOMatrix
,
int64_t
);
template
<
typename
IdType
>
template
<
typename
IdType
>
__global__
void
_COOGetAllRowNNZKernel
(
__global__
void
_COOGetAllRowNNZKernel
(
...
@@ -137,8 +137,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
...
@@ -137,8 +137,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
}
}
}
}
template
NDArray
COOGetRowNNZ
<
kDL
GPU
,
int32_t
>(
COOMatrix
,
NDArray
);
template
NDArray
COOGetRowNNZ
<
kDL
ROCM
,
int32_t
>(
COOMatrix
,
NDArray
);
template
NDArray
COOGetRowNNZ
<
kDL
GPU
,
int64_t
>(
COOMatrix
,
NDArray
);
template
NDArray
COOGetRowNNZ
<
kDL
ROCM
,
int64_t
>(
COOMatrix
,
NDArray
);
}
// namespace impl
}
// namespace impl
}
// namespace aten
}
// namespace aten
...
...
src/array/cuda/spmat_op_impl_csr.cu
View file @
aaaecbc9
...
@@ -43,8 +43,8 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
...
@@ -43,8 +43,8 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
return
*
out
.
Ptr
<
IdType
>
()
!=
-
1
;
return
*
out
.
Ptr
<
IdType
>
()
!=
-
1
;
}
}
template
bool
CSRIsNonZero
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
bool
CSRIsNonZero
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
bool
CSRIsNonZero
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
bool
CSRIsNonZero
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
<
DLDeviceType
XPU
,
typename
IdType
>
template
<
DLDeviceType
XPU
,
typename
IdType
>
NDArray
CSRIsNonZero
(
CSRMatrix
csr
,
NDArray
row
,
NDArray
col
)
{
NDArray
CSRIsNonZero
(
CSRMatrix
csr
,
NDArray
row
,
NDArray
col
)
{
...
@@ -70,8 +70,8 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
...
@@ -70,8 +70,8 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return
rst
!=
-
1
;
return
rst
!=
-
1
;
}
}
template
NDArray
CSRIsNonZero
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
NDArray
,
NDArray
);
template
NDArray
CSRIsNonZero
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
NDArray
,
NDArray
);
template
NDArray
CSRIsNonZero
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
NDArray
,
NDArray
);
template
NDArray
CSRIsNonZero
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
NDArray
,
NDArray
);
///////////////////////////// CSRHasDuplicate /////////////////////////////
///////////////////////////// CSRHasDuplicate /////////////////////////////
...
@@ -117,8 +117,8 @@ bool CSRHasDuplicate(CSRMatrix csr) {
...
@@ -117,8 +117,8 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return
!
ret
;
return
!
ret
;
}
}
template
bool
CSRHasDuplicate
<
kDL
GPU
,
int32_t
>(
CSRMatrix
csr
);
template
bool
CSRHasDuplicate
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
);
template
bool
CSRHasDuplicate
<
kDL
GPU
,
int64_t
>(
CSRMatrix
csr
);
template
bool
CSRHasDuplicate
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
///////////////////////////// CSRGetRowNNZ /////////////////////////////
...
@@ -129,8 +129,8 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
...
@@ -129,8 +129,8 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
return
next
-
cur
;
return
next
-
cur
;
}
}
template
int64_t
CSRGetRowNNZ
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
int64_t
CSRGetRowNNZ
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
int64_t
CSRGetRowNNZ
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
int64_t
);
template
int64_t
CSRGetRowNNZ
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
int64_t
);
template
<
typename
IdType
>
template
<
typename
IdType
>
__global__
void
_CSRGetRowNNZKernel
(
__global__
void
_CSRGetRowNNZKernel
(
...
@@ -163,8 +163,8 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
...
@@ -163,8 +163,8 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return
rst
;
return
rst
;
}
}
template
NDArray
CSRGetRowNNZ
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
NDArray
);
template
NDArray
CSRGetRowNNZ
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
NDArray
);
template
NDArray
CSRGetRowNNZ
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
NDArray
);
template
NDArray
CSRGetRowNNZ
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
NDArray
);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
...
@@ -175,8 +175,8 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
...
@@ -175,8 +175,8 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
return
csr
.
indices
.
CreateView
({
len
},
csr
.
indices
->
dtype
,
offset
);
return
csr
.
indices
.
CreateView
({
len
},
csr
.
indices
->
dtype
,
offset
);
}
}
template
NDArray
CSRGetRowColumnIndices
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowColumnIndices
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowColumnIndices
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowColumnIndices
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
int64_t
);
///////////////////////////// CSRGetRowData /////////////////////////////
///////////////////////////// CSRGetRowData /////////////////////////////
...
@@ -190,8 +190,8 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
...
@@ -190,8 +190,8 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return
aten
::
Range
(
offset
,
offset
+
len
,
csr
.
indptr
->
dtype
.
bits
,
csr
.
indptr
->
ctx
);
return
aten
::
Range
(
offset
,
offset
+
len
,
csr
.
indptr
->
dtype
.
bits
,
csr
.
indptr
->
ctx
);
}
}
template
NDArray
CSRGetRowData
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowData
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowData
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
int64_t
);
template
NDArray
CSRGetRowData
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
int64_t
);
///////////////////////////// CSRSliceRows /////////////////////////////
///////////////////////////// CSRSliceRows /////////////////////////////
...
@@ -216,8 +216,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
...
@@ -216,8 +216,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr
.
sorted
);
csr
.
sorted
);
}
}
template
CSRMatrix
CSRSliceRows
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
CSRMatrix
CSRSliceRows
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
CSRMatrix
CSRSliceRows
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
template
CSRMatrix
CSRSliceRows
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
int64_t
,
int64_t
);
/*!
/*!
* \brief Copy data segment to output buffers
* \brief Copy data segment to output buffers
...
@@ -273,8 +273,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
...
@@ -273,8 +273,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
csr
.
sorted
);
csr
.
sorted
);
}
}
template
CSRMatrix
CSRSliceRows
<
kDL
GPU
,
int32_t
>(
CSRMatrix
,
NDArray
);
template
CSRMatrix
CSRSliceRows
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
,
NDArray
);
template
CSRMatrix
CSRSliceRows
<
kDL
GPU
,
int64_t
>(
CSRMatrix
,
NDArray
);
template
CSRMatrix
CSRSliceRows
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
,
NDArray
);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
...
@@ -393,9 +393,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
...
@@ -393,9 +393,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
return
{
ret_row
,
ret_col
,
ret_data
};
return
{
ret_row
,
ret_col
,
ret_data
};
}
}
template
std
::
vector
<
NDArray
>
CSRGetDataAndIndices
<
kDL
GPU
,
int32_t
>
(
template
std
::
vector
<
NDArray
>
CSRGetDataAndIndices
<
kDL
ROCM
,
int32_t
>
(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
);
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
);
template
std
::
vector
<
NDArray
>
CSRGetDataAndIndices
<
kDL
GPU
,
int64_t
>
(
template
std
::
vector
<
NDArray
>
CSRGetDataAndIndices
<
kDL
ROCM
,
int64_t
>
(
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
);
CSRMatrix
csr
,
NDArray
rows
,
NDArray
cols
);
///////////////////////////// CSRSliceMatrix /////////////////////////////
///////////////////////////// CSRSliceMatrix /////////////////////////////
...
@@ -502,9 +502,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
...
@@ -502,9 +502,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
ret_col
,
ret_data
);
ret_col
,
ret_data
);
}
}
template
CSRMatrix
CSRSliceMatrix
<
kDL
GPU
,
int32_t
>(
template
CSRMatrix
CSRSliceMatrix
<
kDL
ROCM
,
int32_t
>(
CSRMatrix
csr
,
runtime
::
NDArray
rows
,
runtime
::
NDArray
cols
);
CSRMatrix
csr
,
runtime
::
NDArray
rows
,
runtime
::
NDArray
cols
);
template
CSRMatrix
CSRSliceMatrix
<
kDL
GPU
,
int64_t
>(
template
CSRMatrix
CSRSliceMatrix
<
kDL
ROCM
,
int64_t
>(
CSRMatrix
csr
,
runtime
::
NDArray
rows
,
runtime
::
NDArray
cols
);
CSRMatrix
csr
,
runtime
::
NDArray
rows
,
runtime
::
NDArray
cols
);
}
// namespace impl
}
// namespace impl
...
...
src/array/cuda/spmm.cu
View file @
aaaecbc9
...
@@ -147,53 +147,53 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
...
@@ -147,53 +147,53 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
}
}
}
template
void
SpMMCsr
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCsr
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCsr
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCsr
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCsr
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCsr
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SpMMCsr
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
const
BcastOff
&
bcast
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
template
void
SpMMCoo
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SpMMCoo
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
const
BcastOff
&
bcast
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
...
...
src/array/cuda/spmm_hetero.cu
View file @
aaaecbc9
...
@@ -201,37 +201,37 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
...
@@ -201,37 +201,37 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
});
});
}
}
template
void
SpMMCsrHetero
<
kDL
GPU
,
int32_t
,
16
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int32_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
template
void
SpMMCsrHetero
<
kDL
GPU
,
int64_t
,
16
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int64_t
,
16
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
template
void
SpMMCsrHetero
<
kDL
GPU
,
int32_t
,
32
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int32_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
template
void
SpMMCsrHetero
<
kDL
GPU
,
int64_t
,
32
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int64_t
,
32
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
template
void
SpMMCsrHetero
<
kDL
GPU
,
int32_t
,
64
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int32_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
std
::
vector
<
NDArray
>*
out
,
std
::
vector
<
std
::
vector
<
NDArray
>>*
out_aux
,
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
const
std
::
vector
<
dgl_type_t
>&
ufeat_ntids
,
const
std
::
vector
<
dgl_type_t
>&
out_ntids
);
template
void
SpMMCsrHetero
<
kDL
GPU
,
int64_t
,
64
>(
template
void
SpMMCsrHetero
<
kDL
ROCM
,
int64_t
,
64
>(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
BcastOff
&
bcast
,
const
std
::
vector
<
CSRMatrix
>&
csr
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
const
std
::
vector
<
NDArray
>&
ufeat
,
const
std
::
vector
<
NDArray
>&
efeat
,
...
...
src/array/cuda/uvm/array_index_select_uvm.cu
View file @
aaaecbc9
...
@@ -25,7 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
...
@@ -25,7 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
std
::
vector
<
int64_t
>
shape
{
len
};
std
::
vector
<
int64_t
>
shape
{
len
};
CHECK
(
array
.
IsPinned
());
CHECK
(
array
.
IsPinned
());
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
GPU
);
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
ROCM
);
for
(
int
d
=
1
;
d
<
array
->
ndim
;
++
d
)
{
for
(
int
d
=
1
;
d
<
array
->
ndim
;
++
d
)
{
num_feat
*=
array
->
shape
[
d
];
num_feat
*=
array
->
shape
[
d
];
...
@@ -85,8 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
...
@@ -85,8 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
std
::
vector
<
int64_t
>
shape
{
len
};
std
::
vector
<
int64_t
>
shape
{
len
};
CHECK
(
dest
.
IsPinned
());
CHECK
(
dest
.
IsPinned
());
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
GPU
);
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
ROCM
);
CHECK_EQ
(
source
->
ctx
.
device_type
,
kDL
GPU
);
CHECK_EQ
(
source
->
ctx
.
device_type
,
kDL
ROCM
);
for
(
int
d
=
1
;
d
<
source
->
ndim
;
++
d
)
{
for
(
int
d
=
1
;
d
<
source
->
ndim
;
++
d
)
{
num_feat
*=
source
->
shape
[
d
];
num_feat
*=
source
->
shape
[
d
];
...
...
src/array/filter.cc
View file @
aaaecbc9
...
@@ -23,10 +23,10 @@ DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
...
@@ -23,10 +23,10 @@ DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
IdArray
array
=
args
[
0
];
IdArray
array
=
args
[
0
];
auto
ctx
=
array
->
ctx
;
auto
ctx
=
array
->
ctx
;
// TODO(nv-dlasalle): Implement CPU version.
// TODO(nv-dlasalle): Implement CPU version.
if
(
ctx
.
device_type
==
kDL
GPU
)
{
if
(
ctx
.
device_type
==
kDL
ROCM
)
{
#ifdef DGL_USE_CUDA
#ifdef DGL_USE_CUDA
ATEN_ID_TYPE_SWITCH
(
array
->
dtype
,
IdType
,
{
ATEN_ID_TYPE_SWITCH
(
array
->
dtype
,
IdType
,
{
*
rv
=
CreateSetFilter
<
kDL
GPU
,
IdType
>
(
array
);
*
rv
=
CreateSetFilter
<
kDL
ROCM
,
IdType
>
(
array
);
});
});
#else
#else
LOG
(
FATAL
)
<<
"GPU support not compiled."
;
LOG
(
FATAL
)
<<
"GPU support not compiled."
;
...
...
src/array/uvm_array.cc
View file @
aaaecbc9
...
@@ -16,7 +16,7 @@ namespace aten {
...
@@ -16,7 +16,7 @@ namespace aten {
NDArray
IndexSelectCPUFromGPU
(
NDArray
array
,
IdArray
index
)
{
NDArray
IndexSelectCPUFromGPU
(
NDArray
array
,
IdArray
index
)
{
#ifdef DGL_USE_CUDA
#ifdef DGL_USE_CUDA
CHECK
(
array
.
IsPinned
())
<<
"Input array must be in pinned memory."
;
CHECK
(
array
.
IsPinned
())
<<
"Input array must be in pinned memory."
;
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
GPU
)
<<
"Index must be on the GPU."
;
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
ROCM
)
<<
"Index must be on the GPU."
;
CHECK_GE
(
array
->
ndim
,
1
)
<<
"Input array must have at least 1 dimension."
;
CHECK_GE
(
array
->
ndim
,
1
)
<<
"Input array must have at least 1 dimension."
;
CHECK_EQ
(
index
->
ndim
,
1
)
<<
"Index must be a 1D array."
;
CHECK_EQ
(
index
->
ndim
,
1
)
<<
"Index must be a 1D array."
;
...
@@ -34,8 +34,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
...
@@ -34,8 +34,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
void
IndexScatterGPUToCPU
(
NDArray
dest
,
IdArray
index
,
NDArray
source
)
{
void
IndexScatterGPUToCPU
(
NDArray
dest
,
IdArray
index
,
NDArray
source
)
{
#ifdef DGL_USE_CUDA
#ifdef DGL_USE_CUDA
CHECK
(
dest
.
IsPinned
())
<<
"Destination array must be in pinned memory."
;
CHECK
(
dest
.
IsPinned
())
<<
"Destination array must be in pinned memory."
;
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
GPU
)
<<
"Index must be on the GPU."
;
CHECK_EQ
(
index
->
ctx
.
device_type
,
kDL
ROCM
)
<<
"Index must be on the GPU."
;
CHECK_EQ
(
source
->
ctx
.
device_type
,
kDL
GPU
)
<<
"Source array must be on the GPU."
;
CHECK_EQ
(
source
->
ctx
.
device_type
,
kDL
ROCM
)
<<
"Source array must be on the GPU."
;
CHECK_EQ
(
dest
->
dtype
,
source
->
dtype
)
<<
"Destination array and source "
CHECK_EQ
(
dest
->
dtype
,
source
->
dtype
)
<<
"Destination array and source "
"array must have the same dtype."
;
"array must have the same dtype."
;
CHECK_GE
(
dest
->
ndim
,
1
)
<<
"Destination array must have at least 1 dimension."
;
CHECK_GE
(
dest
->
ndim
,
1
)
<<
"Destination array must have at least 1 dimension."
;
...
...
src/geometry/cuda/edge_coarsening_impl.cu
View file @
aaaecbc9
...
@@ -183,13 +183,13 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
...
@@ -183,13 +183,13 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
}
}
device
->
FreeWorkspace
(
ctx
,
prop
);
device
->
FreeWorkspace
(
ctx
,
prop
);
}
}
template
void
WeightedNeighborMatching
<
kDL
GPU
,
float
,
int32_t
>(
template
void
WeightedNeighborMatching
<
kDL
ROCM
,
float
,
int32_t
>(
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
template
void
WeightedNeighborMatching
<
kDL
GPU
,
float
,
int64_t
>(
template
void
WeightedNeighborMatching
<
kDL
ROCM
,
float
,
int64_t
>(
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
template
void
WeightedNeighborMatching
<
kDL
GPU
,
double
,
int32_t
>(
template
void
WeightedNeighborMatching
<
kDL
ROCM
,
double
,
int32_t
>(
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
template
void
WeightedNeighborMatching
<
kDL
GPU
,
double
,
int64_t
>(
template
void
WeightedNeighborMatching
<
kDL
ROCM
,
double
,
int64_t
>(
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
const
aten
::
CSRMatrix
&
csr
,
const
NDArray
weight
,
IdArray
result
);
/*! \brief Unweighted neighbor matching procedure (GPU version).
/*! \brief Unweighted neighbor matching procedure (GPU version).
...
@@ -222,8 +222,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
...
@@ -222,8 +222,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
WeightedNeighborMatching
<
XPU
,
float
,
IdType
>
(
csr
,
weight
,
result
);
WeightedNeighborMatching
<
XPU
,
float
,
IdType
>
(
csr
,
weight
,
result
);
}
}
template
void
NeighborMatching
<
kDL
GPU
,
int32_t
>(
const
aten
::
CSRMatrix
&
csr
,
IdArray
result
);
template
void
NeighborMatching
<
kDL
ROCM
,
int32_t
>(
const
aten
::
CSRMatrix
&
csr
,
IdArray
result
);
template
void
NeighborMatching
<
kDL
GPU
,
int64_t
>(
const
aten
::
CSRMatrix
&
csr
,
IdArray
result
);
template
void
NeighborMatching
<
kDL
ROCM
,
int64_t
>(
const
aten
::
CSRMatrix
&
csr
,
IdArray
result
);
}
// namespace impl
}
// namespace impl
}
// namespace geometry
}
// namespace geometry
...
...
src/geometry/cuda/geometry_op_impl.cu
View file @
aaaecbc9
...
@@ -116,16 +116,16 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
...
@@ -116,16 +116,16 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
point_in_batch
,
dim
,
start_idx_data
,
dist_data
,
ret_data
);
point_in_batch
,
dim
,
start_idx_data
,
dist_data
,
ret_data
);
}
}
template
void
FarthestPointSampler
<
kDL
GPU
,
float
,
int32_t
>(
template
void
FarthestPointSampler
<
kDL
ROCM
,
float
,
int32_t
>(
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
template
void
FarthestPointSampler
<
kDL
GPU
,
float
,
int64_t
>(
template
void
FarthestPointSampler
<
kDL
ROCM
,
float
,
int64_t
>(
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
template
void
FarthestPointSampler
<
kDL
GPU
,
double
,
int32_t
>(
template
void
FarthestPointSampler
<
kDL
ROCM
,
double
,
int32_t
>(
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
template
void
FarthestPointSampler
<
kDL
GPU
,
double
,
int64_t
>(
template
void
FarthestPointSampler
<
kDL
ROCM
,
double
,
int64_t
>(
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
array
,
int64_t
batch_size
,
int64_t
sample_points
,
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
NDArray
dist
,
IdArray
start_idx
,
IdArray
result
);
...
...
src/graph/heterograph.h
View file @
aaaecbc9
...
@@ -237,7 +237,7 @@ class HeteroGraph : public BaseHeteroGraph {
...
@@ -237,7 +237,7 @@ class HeteroGraph : public BaseHeteroGraph {
* \note The graph will be pinned inplace. Behavior depends on the current context,
* \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDLCPU: will be pinned;
* IsPinned: directly return;
* IsPinned: directly return;
* kDL
GPU
: invalid, will throw an error.
* kDL
ROCM
: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
* The context check is deferred to pinning the NDArray.
*/
*/
void
PinMemory_
()
override
;
void
PinMemory_
()
override
;
...
...
src/graph/sampling/randomwalks/get_node_types_gpu.cu
View file @
aaaecbc9
...
@@ -61,11 +61,11 @@ TypeArray GetNodeTypesFromMetapath(
...
@@ -61,11 +61,11 @@ TypeArray GetNodeTypesFromMetapath(
}
}
template
template
TypeArray
GetNodeTypesFromMetapath
<
kDL
GPU
,
int32_t
>(
TypeArray
GetNodeTypesFromMetapath
<
kDL
ROCM
,
int32_t
>(
const
HeteroGraphPtr
hg
,
const
HeteroGraphPtr
hg
,
const
TypeArray
metapath
);
const
TypeArray
metapath
);
template
template
TypeArray
GetNodeTypesFromMetapath
<
kDL
GPU
,
int64_t
>(
TypeArray
GetNodeTypesFromMetapath
<
kDL
ROCM
,
int64_t
>(
const
HeteroGraphPtr
hg
,
const
HeteroGraphPtr
hg
,
const
TypeArray
metapath
);
const
TypeArray
metapath
);
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment