OpenDAS/dgl · Commit fd4ce7cc

[GraphBolt][CUDA] Optimize `gb.isin` and refactor sort use in codebase (#6840)

Authored Dec 27, 2023 by Muhammed Fatih BALIN, committed by GitHub on Dec 27, 2023. Parent: 0cb309a1.
Showing 4 changed files with 119 additions and 75 deletions:

- graphbolt/include/graphbolt/cuda_ops.h (+32, −5)
- graphbolt/src/cuda/isin.cu (+1, −1)
- graphbolt/src/cuda/sort_impl.cu (+55, −27)
- graphbolt/src/cuda/unique_and_compact_impl.cu (+31, −42)
graphbolt/include/graphbolt/cuda_ops.h

```diff
@@ -7,22 +7,49 @@
 #include <torch/script.h>

 #include <type_traits>

 namespace graphbolt {
 namespace ops {

 /**
- * @brief Sorts the given input and also returns the original indexes.
+ * @brief Sorts the given input and optionally returns the original indexes.
+ *
+ * @param input A pointer to storage containing IDs.
+ * @param num_items Size of the input storage.
+ * @param num_bits An integer such that all elements of the input tensor are
+ * less than (1 << num_bits).
+ *
+ * @return
+ * - A pair of tensors if return_original_positions is true, where the first
+ * one includes the sorted input and the second contains the original
+ * positions of the sorted result. If return_original_positions is false,
+ * returns only the sorted input.
+ */
+template <bool return_original_positions, typename scalar_t>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(const scalar_t* input, int64_t num_items, int num_bits);
+
+/**
+ * @brief Sorts the given input and optionally returns the original indexes.
  *
  * @param input A tensor containing IDs.
  * @param num_bits An integer such that all elements of the input tensor are
  * less than (1 << num_bits).
  *
  * @return
- * - A pair of tensors, the first one includes the sorted input, the second
- * contains the original positions of the sorted result.
+ * - A pair of tensors if return_original_positions is true, where the first
+ * one includes the sorted input and the second contains the original
+ * positions of the sorted result. If return_original_positions is false,
+ * returns only the sorted input.
  */
-std::pair<torch::Tensor, torch::Tensor> Sort(
-    torch::Tensor input, int num_bits = 0);
+template <bool return_original_positions = true>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(torch::Tensor input, int num_bits = 0);

 /**
  * @brief Tests if each element of elements is in test_elements. Returns a
```
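For orientation, here is a minimal caller sketch (not part of the commit; `SortDemo` is a hypothetical name) showing how the `return_original_positions` template parameter selects the return type at compile time via `std::conditional_t`:

```cpp
// Hypothetical caller; assumes the GraphBolt headers and a CUDA tensor `ids`.
#include <graphbolt/cuda_ops.h>

void SortDemo(torch::Tensor ids) {
  // Default instantiation (return_original_positions = true): yields a
  // std::pair of the sorted keys and their original positions.
  auto [sorted, positions] = graphbolt::ops::Sort(ids);

  // Keys-only instantiation: the conditional return type collapses to a
  // single torch::Tensor, so no positions tensor is allocated at all.
  torch::Tensor sorted_only = graphbolt::ops::Sort<false>(ids);
}
```

The keys-only instantiation is what the `gb.isin` change below takes advantage of.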
graphbolt/src/cuda/isin.cu

```diff
@@ -15,7 +15,7 @@ namespace graphbolt {
 namespace ops {

 torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
-  auto sorted_test_elements = Sort(test_elements).first;
+  auto sorted_test_elements = Sort<false>(test_elements);
   auto allocator = cuda::GetAllocator();
   auto stream = cuda::GetCurrentStream();
   const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
```
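This one-line change is the `gb.isin` optimization: `Sort(test_elements).first` sorted key/position pairs and threw the positions away, whereas `Sort<false>(test_elements)` sorts keys only. For intuition, here is a self-contained Thrust sketch of the membership-test pattern `IsIn` implements (illustrative only; the real code uses GraphBolt's caching allocator, the current stream, and the shared `Sort<false>` helper):

```cpp
#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>

thrust::device_vector<bool> IsInSketch(
    const thrust::device_vector<int64_t>& elements,
    thrust::device_vector<int64_t> test_elements) {  // by value; sorted below
  // Sort the test set once...
  thrust::sort(test_elements.begin(), test_elements.end());
  thrust::device_vector<bool> result(elements.size());
  // ...then answer one membership query per element with a vectorized
  // binary search over the sorted test set.
  thrust::binary_search(
      test_elements.begin(), test_elements.end(), elements.begin(),
      elements.end(), result.begin());
  return result;
}
```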
graphbolt/src/cuda/sort_impl.cu

```diff
@@ -15,36 +15,64 @@
 namespace graphbolt {
 namespace ops {

-std::pair<torch::Tensor, torch::Tensor> Sort(
-    torch::Tensor input, int num_bits) {
-  int64_t num_items = input.size(0);
-  // We utilize int64_t for the values array. (torch::kLong == int64_t)
-  auto original_idx =
-      torch::arange(num_items, input.options().dtype(torch::kLong));
-  auto sorted_array = torch::empty_like(input);
-  auto sorted_idx = torch::empty_like(original_idx);
+template <bool return_original_positions, typename scalar_t>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
+  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
   auto allocator = cuda::GetAllocator();
   auto stream = cuda::GetCurrentStream();
-  AT_DISPATCH_INDEX_TYPES(
-      input.scalar_type(), "SortImpl", ([&] {
-        const auto input_keys = input.data_ptr<index_t>();
-        const int64_t* input_values = original_idx.data_ptr<int64_t>();
-        index_t* sorted_keys = sorted_array.data_ptr<index_t>();
-        int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
-        if (num_bits == 0) {
-          num_bits = sizeof(index_t) * 8;
-        }
-        size_t tmp_storage_size = 0;
-        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-            nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
-            sorted_values, num_items, 0, num_bits, stream));
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-            tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
-            input_values, sorted_values, num_items, 0, num_bits, stream));
-      }));
-  return std::make_pair(sorted_array, sorted_idx);
+  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
+  auto sorted_array = torch::empty(num_items, options.dtype(dtype));
+  auto sorted_keys = sorted_array.data_ptr<scalar_t>();
+  if (num_bits == 0) {
+    num_bits = sizeof(scalar_t) * 8;
+  }
+  if constexpr (return_original_positions) {
+    // We utilize int64_t for the values array. (torch::kLong == int64_t)
+    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
+    auto sorted_idx = torch::empty_like(original_idx);
+    const int64_t* input_values = original_idx.data_ptr<int64_t>();
+    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
+    size_t tmp_storage_size = 0;
+    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+        nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
+        sorted_values, num_items, 0, num_bits, stream));
+    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
+    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
+        input_values, sorted_values, num_items, 0, num_bits, stream));
+    return std::make_pair(sorted_array, sorted_idx);
+  } else {
+    size_t tmp_storage_size = 0;
+    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
+        nullptr, tmp_storage_size, input_keys, sorted_keys, num_items, 0,
+        num_bits, stream));
+    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
+    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
+        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
+        num_items, 0, num_bits, stream));
+    return sorted_array;
+  }
+}
+
+template <bool return_original_positions>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(torch::Tensor input, int num_bits) {
+  return AT_DISPATCH_INTEGRAL_TYPES(
+      input.scalar_type(), "SortImpl", ([&] {
+        return Sort<return_original_positions>(
+            input.data_ptr<scalar_t>(), input.size(0), num_bits);
+      }));
+}
+
+template torch::Tensor Sort<false>(torch::Tensor input, int num_bits);
+template std::pair<torch::Tensor, torch::Tensor> Sort<true>(
+    torch::Tensor input, int num_bits);

 }  // namespace ops
 }  // namespace graphbolt
```
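Both branches above follow CUB's standard two-phase calling convention: the first call with a null workspace pointer only computes `tmp_storage_size`, and only the second call sorts. A standalone sketch of that pattern (plain `cudaMalloc` stands in for GraphBolt's caching allocator and `CUDA_CALL` error checking):

```cpp
#include <cub/cub.cuh>

cudaError_t SortKeysTwoPhase(
    const int32_t* d_keys_in, int32_t* d_keys_out, int num_items,
    cudaStream_t stream) {
  size_t tmp_bytes = 0;
  // Phase 1: size query; nothing is sorted while the workspace is nullptr.
  cub::DeviceRadixSort::SortKeys(
      nullptr, tmp_bytes, d_keys_in, d_keys_out, num_items,
      /*begin_bit=*/0, /*end_bit=*/sizeof(int32_t) * 8, stream);
  void* d_tmp = nullptr;
  cudaError_t err = cudaMalloc(&d_tmp, tmp_bytes);
  if (err != cudaSuccess) return err;
  // Phase 2: the actual radix sort over the requested bit range.
  err = cub::DeviceRadixSort::SortKeys(
      d_tmp, tmp_bytes, d_keys_in, d_keys_out, num_items, 0,
      sizeof(int32_t) * 8, stream);
  cudaFree(d_tmp);
  return err;
}
```

Passing a tight `num_bits` (as callers do when all IDs are known to be below `1 << num_bits`) lets the radix sort skip passes over the high-order bits; the `num_bits == 0` default falls back to the full width of the key type.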
graphbolt/src/cuda/unique_and_compact_impl.cu

```diff
@@ -63,69 +63,58 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     // Sort the unique_dst_ids tensor.
-    auto sorted_unique_dst_ids =
-        allocator.AllocateStorage<scalar_t>(unique_dst_ids.size(0));
-    {
-      size_t workspace_size;
-      CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-          nullptr, workspace_size, unique_dst_ids_ptr,
-          sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
-          stream));
-      auto temp = allocator.AllocateStorage<char>(workspace_size);
-      CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-          temp.get(), workspace_size, unique_dst_ids_ptr,
-          sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
-          stream));
-    }
+    auto sorted_unique_dst_ids =
+        Sort<false>(unique_dst_ids_ptr, unique_dst_ids.size(0), num_bits);
+    auto sorted_unique_dst_ids_ptr =
+        sorted_unique_dst_ids.data_ptr<scalar_t>();

     // Mark dst nodes in the src_ids tensor.
     auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
     thrust::binary_search(
-        exec_policy, sorted_unique_dst_ids.get(),
-        sorted_unique_dst_ids.get() + unique_dst_ids.size(0), src_ids_ptr,
+        exec_policy, sorted_unique_dst_ids_ptr,
+        sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
         src_ids_ptr + src_ids.size(0), is_dst.get());

     // Filter the non-dst nodes in the src_ids tensor, hence only_src.
-    auto only_src = allocator.AllocateStorage<scalar_t>(src_ids.size(0));
-    auto only_src_size =
-        thrust::remove_copy_if(
-            exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-            is_dst.get(), only_src.get(), thrust::identity<bool>{}) -
-        only_src.get();
-    auto sorted_only_src = allocator.AllocateStorage<scalar_t>(only_src_size);
-    {
-      // Sort the only_src tensor so that we can unique it with Encode
-      // operation later.
-      size_t workspace_size;
-      CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-          nullptr, workspace_size, only_src.get(), sorted_only_src.get(),
-          only_src_size, 0, num_bits, stream));
-      auto temp = allocator.AllocateStorage<char>(workspace_size);
-      CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-          temp.get(), workspace_size, only_src.get(), sorted_only_src.get(),
-          only_src_size, 0, num_bits, stream));
-    }
+    auto only_src =
+        torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
+    {
+      auto only_src_size =
+          thrust::remove_copy_if(
+              exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
+              is_dst.get(), only_src.data_ptr<scalar_t>(),
+              thrust::identity<bool>{}) -
+          only_src.data_ptr<scalar_t>();
+      only_src = only_src.slice(0, 0, only_src_size);
+    }

-    auto unique_only_src = torch::empty(only_src_size, src_ids.options());
+    // Sort the only_src tensor so that we can unique it with Encode
+    // operation later.
+    auto sorted_only_src = Sort<false>(
+        only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);
+
+    auto unique_only_src = torch::empty(only_src.size(0), src_ids.options());
     auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();

     auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
     {
       // Compute the unique operation on the only_src tensor.
       size_t workspace_size;
       CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-          nullptr, workspace_size, sorted_only_src.get(), unique_only_src_ptr,
-          cub::DiscardOutputIterator{}, unique_only_src_cnt.get(),
-          only_src_size, stream));
+          nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
+          unique_only_src_ptr, cub::DiscardOutputIterator{},
+          unique_only_src_cnt.get(), only_src.size(0), stream));
       auto temp = allocator.AllocateStorage<char>(workspace_size);
       CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-          temp.get(), workspace_size, sorted_only_src.get(),
-          unique_only_src_ptr, cub::DiscardOutputIterator{},
-          unique_only_src_cnt.get(), only_src_size, stream));
-      auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
-      unique_only_src = unique_only_src.slice(
-          0, 0, static_cast<scalar_t>(unique_only_src_size));
+          temp.get(), workspace_size, sorted_only_src.data_ptr<scalar_t>(),
+          unique_only_src_ptr, cub::DiscardOutputIterator{},
+          unique_only_src_cnt.get(), only_src.size(0), stream));
     }
+    auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
+    unique_only_src = unique_only_src.slice(
+        0, 0, static_cast<scalar_t>(unique_only_src_size));

     auto real_order = torch::cat({unique_dst_ids, unique_only_src});
     // Sort here so that binary search can be used to lookup new_ids.
     auto [sorted_order, new_ids] = Sort(real_order, num_bits);
```
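The surviving `DeviceRunLengthEncode` block keeps the sort-then-encode idiom for deduplication: once equal IDs are adjacent, `Encode` emits each distinct value once plus the number of distinct runs. A standalone sketch of that step under the same two-phase convention (names and the raw CUDA allocation are illustrative, not GraphBolt API):

```cpp
#include <cub/cub.cuh>

// d_sorted must already be sorted so that equal IDs form contiguous runs.
cudaError_t UniqueFromSorted(
    const int64_t* d_sorted, int64_t* d_unique_out, int* d_num_unique_out,
    int num_items, cudaStream_t stream) {
  cub::DiscardOutputIterator<> discard_run_lengths;  // run lengths unused
  size_t tmp_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(
      nullptr, tmp_bytes, d_sorted, d_unique_out, discard_run_lengths,
      d_num_unique_out, num_items, stream);
  void* d_tmp = nullptr;
  cudaError_t err = cudaMalloc(&d_tmp, tmp_bytes);
  if (err != cudaSuccess) return err;
  err = cub::DeviceRunLengthEncode::Encode(
      d_tmp, tmp_bytes, d_sorted, d_unique_out, discard_run_lengths,
      d_num_unique_out, num_items, stream);
  cudaFree(d_tmp);
  return err;
}
```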