"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "401e1278fed12899b21b19d3dd3f95d138a912c7"
Unverified commit fd4ce7cc authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Optimize `gb.isin` and refactor sort use in codebase (#6840)

parent 0cb309a1
@@ -7,22 +7,49 @@
#include <torch/script.h>
#include <type_traits>
namespace graphbolt {
namespace ops {
/**
* @brief Sorts the given input and also returns the original indexes.
* @brief Sorts the given input and optionally returns the original indexes.
*
* @param input A pointer to storage containing IDs.
* @param num_items Size of the input storage.
* @param num_bits An integer such that all elements of input tensor are
* less than (1 << num_bits).
*
* @return
* - A tuple of tensors if return_original_positions is true, where the first
* one includes sorted input, the second contains original positions of the
* sorted result. If return_original_positions is false, then returns only the
* sorted input.
*/
template <bool return_original_positions, typename scalar_t>
std::conditional_t<
return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
torch::Tensor>
Sort(const scalar_t* input, int64_t num_items, int num_bits);
/**
* @brief Sorts the given input and optionally returns the original indexes.
*
* @param input A tensor containing IDs.
* @param num_bits An integer such that all elements of input tensor are
* less than (1 << num_bits).
*
* @return
* - A tuple of tensors, the first one includes sorted input, the second
* contains original positions of the sorted result.
* - A tuple of tensors if return_original_positions is true, where the first
* one includes sorted input, the second contains original positions of the
* sorted result. If return_original_positions is false, then returns only the
* sorted input.
*/
std::pair<torch::Tensor, torch::Tensor> Sort(
torch::Tensor input, int num_bits = 0);
template <bool return_original_positions = true>
std::conditional_t<
return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
torch::Tensor>
Sort(torch::Tensor input, int num_bits = 0);
/**
* @brief Tests if each element of elements is in test_elements. Returns a
......
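With the new declarations, callers choose at compile time whether the original positions are materialized. A minimal usage sketch (illustrative only: the tensor, the include path, and the `num_bits` value are assumptions, not part of this patch):

```cpp
#include <torch/torch.h>

#include "sort.h"  // assumed: the header carrying the Sort declarations above

void SortUsageSketch() {
  // Hypothetical CUDA tensor of node IDs whose values all fit in 20 bits.
  torch::Tensor ids = torch::randint(
      0, 1 << 20, {1000}, torch::dtype(torch::kInt64).device(torch::kCUDA));

  // Default template argument (true): sorted keys plus the original position
  // of each sorted element, via cub::DeviceRadixSort::SortPairs.
  auto [sorted_ids, original_positions] =
      graphbolt::ops::Sort(ids, /*num_bits=*/20);

  // Keys-only variant: no positions tensor is allocated (SortKeys path).
  torch::Tensor sorted_only = graphbolt::ops::Sort<false>(ids, /*num_bits=*/20);
}
```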
@@ -15,7 +15,7 @@ namespace graphbolt {
namespace ops {
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
auto sorted_test_elements = Sort(test_elements).first;
auto sorted_test_elements = Sort<false>(test_elements);
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
......
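Only the sorted keys matter here: `IsIn` answers membership queries with a vectorized binary search over `sorted_test_elements`, so the index tensor produced by the old `Sort(test_elements).first` call was computed only to be thrown away. The rest of the function is elided from this excerpt; a plausible continuation, sketched from the thrust setup shown above (not the verbatim body), is:

```cpp
  // Sketch of a continuation (assumed, not shown in this diff): for each query
  // element, thrust::binary_search writes true iff it occurs in the sorted
  // test set. Requires <thrust/binary_search.h>.
  auto result =
      torch::empty(elements.sizes(), elements.options().dtype(torch::kBool));
  AT_DISPATCH_INTEGRAL_TYPES(
      elements.scalar_type(), "IsInOperation", ([&] {
        const auto test_begin = sorted_test_elements.data_ptr<scalar_t>();
        const auto test_end = test_begin + sorted_test_elements.size(0);
        const auto queries = elements.data_ptr<scalar_t>();
        thrust::binary_search(
            exec_policy, test_begin, test_end, queries,
            queries + elements.size(0), result.data_ptr<bool>());
      }));
  return result;
}
```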
@@ -15,36 +15,64 @@
namespace graphbolt {
namespace ops {
std::pair<torch::Tensor, torch::Tensor> Sort(
torch::Tensor input, int num_bits) {
int64_t num_items = input.size(0);
// We utilize int64_t for the values array. (torch::kLong == int64_t)
auto original_idx =
torch::arange(num_items, input.options().dtype(torch::kLong));
auto sorted_array = torch::empty_like(input);
auto sorted_idx = torch::empty_like(original_idx);
template <bool return_original_positions, typename scalar_t>
std::conditional_t<
return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
torch::Tensor>
Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
AT_DISPATCH_INDEX_TYPES(
input.scalar_type(), "SortImpl", ([&] {
const auto input_keys = input.data_ptr<index_t>();
const int64_t* input_values = original_idx.data_ptr<int64_t>();
index_t* sorted_keys = sorted_array.data_ptr<index_t>();
int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
if (num_bits == 0) {
num_bits = sizeof(index_t) * 8;
}
size_t tmp_storage_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
sorted_values, num_items, 0, num_bits, stream));
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
input_values, sorted_values, num_items, 0, num_bits, stream));
}));
return std::make_pair(sorted_array, sorted_idx);
constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
auto sorted_array = torch::empty(num_items, options.dtype(dtype));
auto sorted_keys = sorted_array.data_ptr<scalar_t>();
if (num_bits == 0) {
num_bits = sizeof(scalar_t) * 8;
}
if constexpr (return_original_positions) {
// We utilize int64_t for the values array. (torch::kLong == int64_t)
auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
auto sorted_idx = torch::empty_like(original_idx);
const int64_t* input_values = original_idx.data_ptr<int64_t>();
int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
size_t tmp_storage_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
sorted_values, num_items, 0, num_bits, stream));
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
input_values, sorted_values, num_items, 0, num_bits, stream));
return std::make_pair(sorted_array, sorted_idx);
} else {
size_t tmp_storage_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, tmp_storage_size, input_keys, sorted_keys, num_items, 0,
num_bits, stream));
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys, num_items,
0, num_bits, stream));
return sorted_array;
}
}
template <bool return_original_positions>
std::conditional_t<
return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
torch::Tensor>
Sort(torch::Tensor input, int num_bits) {
return AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "SortImpl", ([&] {
return Sort<return_original_positions>(
input.data_ptr<scalar_t>(),
input.size(0), num_bits);
}));
}
template torch::Tensor Sort<false>(torch::Tensor input, int num_bits);
template std::pair<torch::Tensor, torch::Tensor> Sort<true>(
torch::Tensor input, int num_bits);
} // namespace ops
} // namespace graphbolt
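The refactor hinges on one compile-time switch: `std::conditional_t` selects the return type and `if constexpr` selects the matching body, so `Sort<false>` never allocates the index tensor and never touches `SortPairs`; the explicit instantiations at the end keep both variants usable from other translation units even though the definition lives in this .cu file. A small host-side analogue of the pattern (illustrative only, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

// Host-side analogue of the branching used by Sort above: the return type and
// the body are chosen together at compile time, so the keys-only instantiation
// never creates a positions array at all.
template <bool return_original_positions, typename T>
std::conditional_t<
    return_original_positions,
    std::pair<std::vector<T>, std::vector<int64_t>>, std::vector<T>>
SortSketch(std::vector<T> keys) {
  if constexpr (return_original_positions) {
    std::vector<int64_t> positions(keys.size());
    std::iota(positions.begin(), positions.end(), int64_t{0});
    std::sort(positions.begin(), positions.end(), [&](int64_t a, int64_t b) {
      return keys[a] < keys[b];
    });
    std::vector<T> sorted_keys(keys.size());
    for (size_t i = 0; i < keys.size(); ++i) sorted_keys[i] = keys[positions[i]];
    return {std::move(sorted_keys), std::move(positions)};
  } else {
    std::sort(keys.begin(), keys.end());
    return keys;
  }
}

// SortSketch<true>(std::vector<int64_t>{3, 1, 2})  -> ({1, 2, 3}, {1, 2, 0})
// SortSketch<false>(std::vector<int64_t>{3, 1, 2}) -> {1, 2, 3}
```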
@@ -63,69 +63,58 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
// Sort the unique_dst_ids tensor.
auto sorted_unique_dst_ids =
allocator.AllocateStorage<scalar_t>(unique_dst_ids.size(0));
{
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
}
Sort<false>(unique_dst_ids_ptr, unique_dst_ids.size(0), num_bits);
auto sorted_unique_dst_ids_ptr =
sorted_unique_dst_ids.data_ptr<scalar_t>();
// Mark dst nodes in the src_ids tensor.
auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
thrust::binary_search(
exec_policy, sorted_unique_dst_ids.get(),
sorted_unique_dst_ids.get() + unique_dst_ids.size(0), src_ids_ptr,
exec_policy, sorted_unique_dst_ids_ptr,
sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
src_ids_ptr + src_ids.size(0), is_dst.get());
// Filter the non-dst nodes in the src_ids tensor, hence only_src.
auto only_src = allocator.AllocateStorage<scalar_t>(src_ids.size(0));
auto only_src_size =
thrust::remove_copy_if(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
is_dst.get(), only_src.get(), thrust::identity<bool>{}) -
only_src.get();
auto sorted_only_src =
allocator.AllocateStorage<scalar_t>(only_src_size);
{ // Sort the only_src tensor so that we can unique it with Encode
// operation later.
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
auto only_src =
torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
{
auto only_src_size =
thrust::remove_copy_if(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
is_dst.get(), only_src.data_ptr<scalar_t>(),
thrust::identity<bool>{}) -
only_src.data_ptr<scalar_t>();
only_src = only_src.slice(0, 0, only_src_size);
}
auto unique_only_src = torch::empty(only_src_size, src_ids.options());
// Sort the only_src tensor so that we can unique it with Encode
// operation later.
auto sorted_only_src = Sort<false>(
only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);
auto unique_only_src =
torch::empty(only_src.size(0), src_ids.options());
auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
{ // Compute the unique operation on the only_src tensor.
size_t workspace_size;
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
nullptr, workspace_size, sorted_only_src.get(),
nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
unique_only_src_cnt.get(), only_src.size(0), stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
temp.get(), workspace_size, sorted_only_src.get(),
temp.get(), workspace_size, sorted_only_src.data_ptr<scalar_t>(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
unique_only_src_cnt.get(), only_src.size(0), stream));
auto unique_only_src_size =
cuda::CopyScalar(unique_only_src_cnt.get());
unique_only_src = unique_only_src.slice(
0, 0, static_cast<scalar_t>(unique_only_src_size));
}
auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
unique_only_src = unique_only_src.slice(
0, 0, static_cast<scalar_t>(unique_only_src_size));
auto real_order = torch::cat({unique_dst_ids, unique_only_src});
// Sort here so that binary search can be used to lookup new_ids.
auto [sorted_order, new_ids] = Sort(real_order, num_bits);
......
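UniqueAndCompact now delegates both keys-only sorts to `Sort<false>`, which removes two hand-written SortKeys blocks and lets `only_src` and `sorted_only_src` live in ordinary tensors instead of raw allocator storage. The surviving inline comment ("Sort the only_src tensor so that we can unique it with Encode operation later") refers to a standard CUB idiom: once equal IDs are adjacent, `cub::DeviceRunLengthEncode::Encode` emits each distinct ID once and reports the unique count. A standalone sketch of that idiom (hypothetical names, plain `cudaMalloc` instead of the GraphBolt allocator):

```cpp
#include <cstdint>

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Deduplicate an already-sorted device array: Encode writes each distinct value
// once to d_unique and the number of distinct values to d_num_unique. The
// per-value run lengths are not needed, hence DiscardOutputIterator.
void UniqueViaEncode(
    const int64_t* d_sorted, int64_t* d_unique, int64_t* d_num_unique,
    int num_items, cudaStream_t stream) {
  size_t tmp_bytes = 0;
  // First call with a null workspace only queries the required temporary size.
  cub::DeviceRunLengthEncode::Encode(
      nullptr, tmp_bytes, d_sorted, d_unique, cub::DiscardOutputIterator<>{},
      d_num_unique, num_items, stream);
  void* d_tmp = nullptr;
  cudaMalloc(&d_tmp, tmp_bytes);
  // Second call performs the actual encode.
  cub::DeviceRunLengthEncode::Encode(
      d_tmp, tmp_bytes, d_sorted, d_unique, cub::DiscardOutputIterator<>{},
      d_num_unique, num_items, stream);
  cudaFree(d_tmp);
}
```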