Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
83115794
Unverified
Commit
83115794
authored
Jul 14, 2023
by
Muhammed Fatih BALIN
Committed by
GitHub
Jul 14, 2023
Browse files
[Performance][CUDA] Faster CSRToCOO (#5648)
Co-authored-by:
Hongzhi (Steve), Chen
<
chenhongzhi.nkcs@gmail.com
>
parent
dc06060b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
33 deletions
+51
-33
CMakeLists.txt
CMakeLists.txt
+1
-0
src/array/cuda/csr2coo.cu
src/array/cuda/csr2coo.cu
+49
-32
third_party/thrust
third_party/thrust
+1
-1
No files found.
CMakeLists.txt
View file @
83115794
...
@@ -50,6 +50,7 @@ if(USE_CUDA)
...
@@ -50,6 +50,7 @@ if(USE_CUDA)
message
(
STATUS
"Use external CUB/Thrust library for a consistent API and performance."
)
message
(
STATUS
"Use external CUB/Thrust library for a consistent API and performance."
)
cuda_include_directories
(
BEFORE
"
${
CMAKE_SOURCE_DIR
}
/third_party/thrust"
)
cuda_include_directories
(
BEFORE
"
${
CMAKE_SOURCE_DIR
}
/third_party/thrust"
)
cuda_include_directories
(
BEFORE
"
${
CMAKE_SOURCE_DIR
}
/third_party/thrust/dependencies/cub"
)
cuda_include_directories
(
BEFORE
"
${
CMAKE_SOURCE_DIR
}
/third_party/thrust/dependencies/cub"
)
cuda_include_directories
(
BEFORE
"
${
CMAKE_SOURCE_DIR
}
/third_party/thrust/dependencies/libcudacxx/include"
)
endif
(
USE_CUDA
)
endif
(
USE_CUDA
)
# initial variables
# initial variables
...
...
src/array/cuda/csr2coo.cu
View file @
83115794
...
@@ -4,8 +4,12 @@
...
@@ -4,8 +4,12 @@
* @brief CSR2COO
* @brief CSR2COO
*/
*/
#include <dgl/array.h>
#include <dgl/array.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
#include "./utils.h"
namespace
dgl
{
namespace
dgl
{
...
@@ -45,33 +49,27 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
...
@@ -45,33 +49,27 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
csr
.
num_rows
,
csr
.
num_cols
,
row
,
indices
,
data
,
true
,
csr
.
sorted
);
csr
.
num_rows
,
csr
.
num_cols
,
row
,
indices
,
data
,
true
,
csr
.
sorted
);
}
}
/**
struct
RepeatIndex
{
* @brief Repeat elements
template
<
typename
IdType
>
* @param val Value to repeat
__host__
__device__
auto
operator
()(
IdType
i
)
{
* @param repeats Number of repeats for each value
return
thrust
::
make_constant_iterator
(
i
);
* @param pos The position of the output buffer to write the value.
* @param out Output buffer.
* @param length Number of values
*
* For example:
* val = [3, 0, 1]
* repeats = [1, 0, 2]
* pos = [0, 1, 1] # write to output buffer position 0, 1, 1
* then,
* out = [3, 1, 1]
*/
template
<
typename
DType
,
typename
IdType
>
__global__
void
_RepeatKernel
(
const
DType
*
val
,
const
IdType
*
pos
,
DType
*
out
,
int64_t
n_row
,
int64_t
length
)
{
IdType
tx
=
static_cast
<
IdType
>
(
blockIdx
.
x
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
stride_x
=
gridDim
.
x
*
blockDim
.
x
;
while
(
tx
<
length
)
{
IdType
i
=
dgl
::
cuda
::
_UpperBound
(
pos
,
n_row
,
tx
)
-
1
;
out
[
tx
]
=
val
[
i
];
tx
+=
stride_x
;
}
}
}
};
template
<
typename
IdType
>
struct
OutputBufferIndexer
{
const
IdType
*
indptr
;
IdType
*
buffer
;
__host__
__device__
auto
operator
()(
IdType
i
)
{
return
buffer
+
indptr
[
i
];
}
};
template
<
typename
IdType
>
struct
AdjacentDifference
{
const
IdType
*
indptr
;
__host__
__device__
auto
operator
()(
IdType
i
)
{
return
indptr
[
i
+
1
]
-
indptr
[
i
];
}
};
template
<
>
template
<
>
COOMatrix
CSRToCOO
<
kDGLCUDA
,
int64_t
>
(
CSRMatrix
csr
)
{
COOMatrix
CSRToCOO
<
kDGLCUDA
,
int64_t
>
(
CSRMatrix
csr
)
{
...
@@ -80,14 +78,33 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
...
@@ -80,14 +78,33 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const
int64_t
nnz
=
csr
.
indices
->
shape
[
0
];
const
int64_t
nnz
=
csr
.
indices
->
shape
[
0
];
const
auto
nbits
=
csr
.
indptr
->
dtype
.
bits
;
const
auto
nbits
=
csr
.
indptr
->
dtype
.
bits
;
IdArray
rowids
=
Range
(
0
,
csr
.
num_rows
,
nbits
,
ctx
);
IdArray
ret_row
=
NewIdArray
(
nnz
,
ctx
,
nbits
);
IdArray
ret_row
=
NewIdArray
(
nnz
,
ctx
,
nbits
);
const
int
nt
=
256
;
runtime
::
CUDAWorkspaceAllocator
allocator
(
csr
.
indptr
->
ctx
);
const
int
nb
=
(
nnz
+
nt
-
1
)
/
nt
;
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
CUDA_KERNEL_CALL
(
_RepeatKernel
,
nb
,
nt
,
0
,
stream
,
rowids
.
Ptr
<
int64_t
>
(),
auto
input_buffer
=
thrust
::
make_transform_iterator
(
iota
,
RepeatIndex
{});
csr
.
indptr
.
Ptr
<
int64_t
>
(),
ret_row
.
Ptr
<
int64_t
>
(),
csr
.
num_rows
,
nnz
);
auto
output_buffer
=
thrust
::
make_transform_iterator
(
iota
,
OutputBufferIndexer
<
int64_t
>
{
csr
.
indptr
.
Ptr
<
int64_t
>
(),
ret_row
.
Ptr
<
int64_t
>
()});
auto
buffer_sizes
=
thrust
::
make_transform_iterator
(
iota
,
AdjacentDifference
<
int64_t
>
{
csr
.
indptr
.
Ptr
<
int64_t
>
()});
constexpr
int64_t
max_copy_at_once
=
std
::
numeric_limits
<
int32_t
>::
max
();
for
(
int64_t
i
=
0
;
i
<
csr
.
num_rows
;
i
+=
max_copy_at_once
)
{
std
::
size_t
temp_storage_bytes
=
0
;
CUDA_CALL
(
cub
::
DeviceCopy
::
Batched
(
nullptr
,
temp_storage_bytes
,
input_buffer
+
i
,
output_buffer
+
i
,
buffer_sizes
+
i
,
std
::
min
(
csr
.
num_rows
-
i
,
max_copy_at_once
),
stream
));
auto
temp
=
allocator
.
alloc_unique
<
char
>
(
temp_storage_bytes
);
CUDA_CALL
(
cub
::
DeviceCopy
::
Batched
(
temp
.
get
(),
temp_storage_bytes
,
input_buffer
+
i
,
output_buffer
+
i
,
buffer_sizes
+
i
,
std
::
min
(
csr
.
num_rows
-
i
,
max_copy_at_once
),
stream
));
}
return
COOMatrix
(
return
COOMatrix
(
csr
.
num_rows
,
csr
.
num_cols
,
ret_row
,
csr
.
indices
,
csr
.
data
,
true
,
csr
.
num_rows
,
csr
.
num_cols
,
ret_row
,
csr
.
indices
,
csr
.
data
,
true
,
...
...
thrust
@
02931a30
Compare
6a3078c6
...
02931a30
Subproject commit
6a3078c64cab0e2f276340fa5dcafa0d758ed890
Subproject commit
02931a309bee769853088b79b4e3ab1c0bd2336c
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment