Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c59000ac
Unverified
Commit
c59000ac
authored
Nov 24, 2022
by
Xin Yao
Committed by
GitHub
Nov 24, 2022
Browse files
[Cleanup] Remove duplicated _IndexSelect (#4874)
parent
0cb5f0fd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
2 additions
and
51 deletions
+2
-51
src/array/cuda/spmm.cu
src/array/cuda/spmm.cu
+1
-1
src/array/cuda/spmm.cuh
src/array/cuda/spmm.cuh
+0
-49
src/array/cuda/spmm_hetero.cu
src/array/cuda/spmm_hetero.cu
+1
-1
No files found.
src/array/cuda/spmm.cu
View file @
c59000ac
...
@@ -45,7 +45,7 @@ void SpMMCsr(
...
@@ -45,7 +45,7 @@ void SpMMCsr(
int64_t
x_length
=
1
;
int64_t
x_length
=
1
;
for
(
int
i
=
1
;
i
<
ufeat
->
ndim
;
++
i
)
x_length
*=
ufeat
->
shape
[
i
];
for
(
int
i
=
1
;
i
<
ufeat
->
ndim
;
++
i
)
x_length
*=
ufeat
->
shape
[
i
];
if
(
!
IsNullArray
(
csr
.
data
))
{
if
(
!
IsNullArray
(
csr
.
data
))
{
efeat
=
_
IndexSelect
<
DType
,
IdType
>
(
efeat
,
csr
.
data
);
efeat
=
IndexSelect
(
efeat
,
csr
.
data
);
}
}
CusparseCsrmm2
<
DType
,
IdType
>
(
CusparseCsrmm2
<
DType
,
IdType
>
(
ufeat
->
ctx
,
csr
,
static_cast
<
DType
*>
(
ufeat
->
data
),
ufeat
->
ctx
,
csr
,
static_cast
<
DType
*>
(
ufeat
->
data
),
...
...
src/array/cuda/spmm.cuh
View file @
c59000ac
...
@@ -98,19 +98,6 @@ cublasStatus_t Xgeam<double>(
...
@@ -98,19 +98,6 @@ cublasStatus_t Xgeam<double>(
handle
,
transa
,
transb
,
m
,
n
,
alpha
,
A
,
lda
,
beta
,
B
,
ldb
,
C
,
ldc
);
handle
,
transa
,
transb
,
m
,
n
,
alpha
,
A
,
lda
,
beta
,
B
,
ldb
,
C
,
ldc
);
}
}
/**
* @brief IndexSelect operator kernel implementation.
* @note duplicate of IndexSelectKernel defined in array_index_select.cu
*/
template
<
typename
DType
,
typename
IdType
>
__global__
void
_IndexSelectKernel
(
const
DType
*
__restrict__
in
,
const
IdType
*
__restrict__
idx
,
DType
*
__restrict__
out
,
int
n
,
int
m
)
{
int
i
=
blockIdx
.
x
;
for
(
int
j
=
threadIdx
.
x
;
j
<
m
;
j
+=
blockDim
.
x
)
out
[
i
*
m
+
j
]
=
in
[
idx
[
i
]
*
m
+
j
];
}
/**
/**
* @brief Transpose operator kernel implementation.
* @brief Transpose operator kernel implementation.
* @note not efficient but it's not a bottleneck, used for float16 dtype.
* @note not efficient but it's not a bottleneck, used for float16 dtype.
...
@@ -168,42 +155,6 @@ void _Transpose<__nv_bfloat16>(
...
@@ -168,42 +155,6 @@ void _Transpose<__nv_bfloat16>(
}
}
#endif // BF16_ENABLED
#endif // BF16_ENABLED
/**
* @brief
*/
template
<
typename
DType
,
typename
IdType
>
__global__
void
_IndexSelectKernel
(
const
DType
*
array
,
const
IdType
*
index
,
int64_t
length
,
DType
*
out
)
{
int
tx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride_x
=
gridDim
.
x
*
blockDim
.
x
;
while
(
tx
<
length
)
{
out
[
tx
]
=
array
[
index
[
tx
]];
tx
+=
stride_x
;
}
}
/* @brief IndexSelect operator.
* @note duplicate of IndexSelect defined in array_op.h but it can
* not be applied to float16 dtype.
*/
template
<
typename
DType
,
typename
IdType
>
NDArray
_IndexSelect
(
NDArray
array
,
NDArray
index
)
{
cudaStream_t
stream
=
runtime
::
getCurrentCUDAStream
();
const
DType
*
array_data
=
static_cast
<
DType
*>
(
array
->
data
);
const
IdType
*
idx_data
=
static_cast
<
IdType
*>
(
index
->
data
);
const
int64_t
arr_len
=
array
->
shape
[
0
];
const
int64_t
len
=
index
->
shape
[
0
];
NDArray
ret
=
NDArray
::
Empty
({
len
},
array
->
dtype
,
array
->
ctx
);
if
(
len
==
0
)
return
ret
;
DType
*
ret_data
=
static_cast
<
DType
*>
(
ret
->
data
);
const
int
nt
=
FindNumThreads
(
len
);
const
int
nb
=
(
len
+
nt
-
1
)
/
nt
;
CUDA_KERNEL_CALL
(
_IndexSelectKernel
,
nb
,
nt
,
0
,
stream
,
array_data
,
idx_data
,
len
,
ret_data
);
return
ret
;
}
#if CUDART_VERSION < 11000
#if CUDART_VERSION < 11000
template
<
typename
DType
>
template
<
typename
DType
>
cusparseStatus_t
Xcsrmm2
(
cusparseStatus_t
Xcsrmm2
(
...
...
src/array/cuda/spmm_hetero.cu
View file @
c59000ac
...
@@ -134,7 +134,7 @@ void SpMMCsrHetero(
...
@@ -134,7 +134,7 @@ void SpMMCsrHetero(
cusparse_available
<
DType
,
IdType
>
(
more_nnz
))
{
// cusparse
cusparse_available
<
DType
,
IdType
>
(
more_nnz
))
{
// cusparse
NDArray
efeat
=
vec_efeat
[
etype
];
NDArray
efeat
=
vec_efeat
[
etype
];
if
(
!
IsNullArray
(
csr
.
data
))
if
(
!
IsNullArray
(
csr
.
data
))
efeat
=
_
IndexSelect
<
DType
,
IdType
>
(
efeat
,
csr
.
data
);
efeat
=
IndexSelect
(
efeat
,
csr
.
data
);
CusparseCsrmm2Hetero
<
DType
,
IdType
>
(
CusparseCsrmm2Hetero
<
DType
,
IdType
>
(
csr
.
indptr
->
ctx
,
csr
,
static_cast
<
DType
*>
(
vec_ufeat
[
src_id
]
->
data
),
csr
.
indptr
->
ctx
,
csr
,
static_cast
<
DType
*>
(
vec_ufeat
[
src_id
]
->
data
),
static_cast
<
DType
*>
(
efeat
->
data
),
static_cast
<
DType
*>
(
efeat
->
data
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment