OpenDAS / dgl · Commit d873b09a (unverified)

Authored Nov 29, 2023 by Muhammed Fatih BALIN; committed by GitHub on Nov 29, 2023.

[Graphbolt][CUDA] Optimize UVA by using CUB sort (#6635)
parent a67d9e6f

Showing 3 changed files with 116 additions and 2 deletions (+116, −2):
graphbolt/src/cuda/common.h (+60, −0)
graphbolt/src/cuda/index_select_impl.cu (+35, −2)
graphbolt/src/cuda/utils.h (+21, −0)
graphbolt/src/cuda/common.h
@@ -8,10 +8,68 @@
```cpp
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/script.h>

#include <memory>
#include <unordered_map>

namespace graphbolt {
namespace cuda {

/**
 * @brief This class is designed to allocate workspace storage and to get a
 * nonblocking thrust execution policy that uses torch's CUDA memory pool and
 * the current cuda stream:
 *
 *   cuda::CUDAWorkspaceAllocator allocator;
 *   const auto stream = torch::cuda::getDefaultCUDAStream();
 *   const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
 *
 * Now, one can pass exec_policy to thrust functions.
 *
 * To get an integer array of size 1000 whose lifetime is managed by
 * unique_ptr, use:
 *
 *   auto int_array = allocator.AllocateStorage<int>(1000);
 *
 * int_array.get() gives the raw pointer.
 */
class CUDAWorkspaceAllocator {
  using TensorPtrMapType = std::unordered_map<void*, torch::Tensor>;
  std::shared_ptr<TensorPtrMapType> ptr_map_;

 public:
  // Required by thrust to satisfy allocator requirements.
  using value_type = char;

  explicit CUDAWorkspaceAllocator()
      : ptr_map_(std::make_shared<TensorPtrMapType>()) {}

  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

  void operator()(void* ptr) const { ptr_map_->erase(ptr); }

  // Required by thrust to satisfy allocator requirements.
  value_type* allocate(std::ptrdiff_t size) const {
    auto tensor = torch::empty(
        size, torch::TensorOptions()
                  .dtype(torch::kByte)
                  .device(c10::DeviceType::CUDA));
    ptr_map_->operator[](tensor.data_ptr()) = tensor;
    return reinterpret_cast<value_type*>(tensor.data_ptr());
  }

  // Required by thrust to satisfy allocator requirements.
  void deallocate(value_type* ptr, std::size_t) const { operator()(ptr); }

  template <typename T>
  std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(allocate(sizeof(T) * size)), *this);
  }
};

template <typename T>
inline bool is_zero(T size) {
  return size == 0;
```
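As a side note, the usage pattern described in the class comment can be exercised end to end roughly as follows. This is a sketch for illustration only, not code from this commit; the `thrust::sequence` call stands in for any thrust algorithm, and the function name is hypothetical.

```cpp
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <torch/script.h>
// Assumes graphbolt's common.h (above) is also on the include path.

// Sketch: drive a thrust algorithm through CUDAWorkspaceAllocator so that
// both thrust's temporary storage and our own scratch buffer come from
// torch's CUDA memory pool.
void FillIotaSketch(int64_t n) {
  graphbolt::cuda::CUDAWorkspaceAllocator allocator;
  const auto stream = torch::cuda::getDefaultCUDAStream();
  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
  // unique_ptr-managed device buffer of n ints.
  auto buf = allocator.AllocateStorage<int>(n);
  // Fill buf with 0, 1, ..., n - 1 on the current stream.
  thrust::sequence(exec_policy, buf.get(), buf.get() + n);
}
```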
@@ -22,6 +80,8 @@ inline bool is_zero<dim3>(dim3 size) {
```cpp
  return size.x == 0 || size.y == 0 || size.z == 0;
}

#define CUDA_CALL(func) C10_CUDA_CHECK((func))

#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
  {                                                                \
    if (!graphbolt::cuda::is_zero((nblks)) &&                      \
```
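The body of CUDA_KERNEL_CALL is truncated by the diff view. Judging from its parameter list, a call site would presumably look like the following; the kernel name, launch dimensions, and arguments are hypothetical.

```cpp
// Hypothetical call site: launches MyKernel on `stream` with the given grid,
// block, and shared-memory configuration; the is_zero guard above suggests
// the launch is skipped when any dimension is zero.
dim3 nblks(128), nthrs(256);
CUDA_KERNEL_CALL(MyKernel, nblks, nthrs, 0, stream, d_in, d_out, n);
```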
graphbolt/src/cuda/index_select_impl.cu
@@ -5,17 +5,49 @@
```cpp
 */
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>

#include <numeric>

#include "../index_select.h"
#include "./common.h"
#include "./utils.h"
#include "cub/cub.cuh"

namespace graphbolt {
namespace ops {

std::pair<torch::Tensor, torch::Tensor> Sort(
    torch::Tensor input, int num_bits) {
  int64_t num_items = input.size(0);
  // We utilize int64_t for the values array. (torch::kLong == int64_t)
  auto original_idx =
      torch::arange(num_items, input.options().dtype(torch::kLong));
  auto sorted_array = torch::empty_like(input);
  auto sorted_idx = torch::empty_like(original_idx);
  cuda::CUDAWorkspaceAllocator allocator;
  AT_DISPATCH_INDEX_TYPES(
      input.scalar_type(), "SortImpl", ([&] {
        using IdType = index_t;
        const auto input_keys = input.data_ptr<index_t>();
        const int64_t* input_values = original_idx.data_ptr<int64_t>();
        IdType* sorted_keys = sorted_array.data_ptr<index_t>();
        int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
        cudaStream_t stream = torch::cuda::getDefaultCUDAStream();
        if (num_bits == 0) {
          num_bits = sizeof(index_t) * 8;
        }
        // First call: query the required temporary workspace size.
        size_t workspace_size = 0;
        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
            nullptr, workspace_size, input_keys, sorted_keys, input_values,
            sorted_values, num_items, 0, num_bits, stream));
        auto temp = allocator.AllocateStorage<char>(workspace_size);
        // Second call: perform the sort using the allocated workspace.
        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
            temp.get(), workspace_size, input_keys, sorted_keys, input_values,
            sorted_values, num_items, 0, num_bits, stream));
      }));
  return std::make_pair(sorted_array, sorted_idx);
}

/** @brief Index select operator implementation for feature size 1. */
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(
```
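The double SortPairs invocation above is the standard CUB calling convention: a first call with a null workspace pointer only computes the temporary storage size, and a second call with the same arguments does the actual sort. A minimal standalone sketch of that idiom, using plain cudaMalloc instead of the torch-pool allocator from this commit (all names here are illustrative):

```cpp
#include <cub/cub.cuh>

// Two-phase cub::DeviceRadixSort::SortPairs idiom, on raw device arrays of
// length n. Only the low num_bits bits of each key participate in the sort.
void SortPairsSketch(
    const int32_t* d_keys_in, int32_t* d_keys_out, const int64_t* d_vals_in,
    int64_t* d_vals_out, int n, int num_bits, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: d_temp == nullptr, so CUB only writes temp_bytes.
  cub::DeviceRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
      /*begin_bit=*/0, /*end_bit=*/num_bits, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: same arguments, now with a real workspace: performs the sort.
  cub::DeviceRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
      /*begin_bit=*/0, /*end_bit=*/num_bits, stream);
  cudaFree(d_temp);
}
```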
@@ -115,7 +147,8 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
```diff
   // Sort the index to improve the memory access pattern.
   torch::Tensor sorted_index, permutation;
-  std::tie(sorted_index, permutation) = torch::sort(index);
+  std::tie(sorted_index, permutation) =
+      Sort(index, cuda::NumberOfBits(input_len));
   const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
   const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
```
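The rest of UVAIndexSelectImpl_ falls outside this hunk. After gathering rows in sorted order, the permutation is typically used to restore the caller's original ordering; a generic sketch of that scatter step (hypothetical code, not from this file):

```cpp
// Hypothetical scatter-back: gathered[i] is the row fetched for
// sorted_index[i], and permutation[i] is that row's original position in
// the caller's index tensor.
template <typename DType>
__global__ void ScatterBackKernel(
    const DType* gathered, const int64_t* permutation, DType* output,
    int64_t num_items) {
  const int64_t i =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < num_items) output[permutation[i]] = gathered[i];
}
```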
graphbolt/src/cuda/utils.h
@@ -31,6 +31,27 @@ inline int FindNumThreads(int size) {
```cpp
  return ret;
}

/**
 * @brief Calculate the smallest number of bits needed to represent a given
 * range of integers [0, range).
 */
template <typename T>
int NumberOfBits(const T& range) {
  if (range <= 1) {
    // ranges of 0 or 1 require no bits to store
    return 0;
  }
  int bits = 1;
  const auto urange = static_cast<std::make_unsigned_t<T>>(range);
  while (bits < static_cast<int>(sizeof(T) * 8) && (1ull << bits) < urange) {
    ++bits;
  }
  return bits;
}

}  // namespace cuda
}  // namespace graphbolt
```
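A quick worked example (values assumed for illustration): for input_len = 1000 the loop stops at bits = 10, since 2^9 = 512 < 1000 <= 2^10 = 1024, so the radix sort in Sort only has to touch the low 10 bits of each key rather than all 32 or 64.

```cpp
// Illustrative only: indices drawn from [0, 1000) fit in 10 bits, so the
// radix sort can skip the upper 54 bits of each int64_t key. `index` is an
// assumed CUDA tensor of indices.
const int bits = graphbolt::cuda::NumberOfBits<int64_t>(1000);  // == 10
auto [sorted_index, permutation] = graphbolt::ops::Sort(index, bits);
```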