"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ed616bd8a8740927770eebe017aedb6204c6105f"
Unverified commit fd4ce7cc, authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Optimize `gb.isin` and refactor sort use in codebase (#6840)

parent 0cb309a1
@@ -7,22 +7,49 @@
 #include <torch/script.h>

+#include <type_traits>
+
 namespace graphbolt {
 namespace ops {

 /**
- * @brief Sorts the given input and also returns the original indexes.
+ * @brief Sorts the given input and optionally returns the original indexes.
+ *
+ * @param input A pointer to storage containing IDs.
+ * @param num_items Size of the input storage.
+ * @param num_bits An integer such that all elements of input tensor are
+ * less than (1 << num_bits).
+ *
+ * @return
+ * - A tuple of tensors if return_original_positions is true, where the first
+ * one includes sorted input, the second contains original positions of the
+ * sorted result. If return_original_positions is false, then returns only the
+ * sorted input.
+ */
+template <bool return_original_positions, typename scalar_t>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(const scalar_t* input, int64_t num_items, int num_bits);
+
+/**
+ * @brief Sorts the given input and optionally returns the original indexes.
  *
  * @param input A tensor containing IDs.
  * @param num_bits An integer such that all elements of input tensor are
  * less than (1 << num_bits).
  *
  * @return
- * - A tuple of tensors, the first one includes sorted input, the second
- * contains original positions of the sorted result.
+ * - A tuple of tensors if return_original_positions is true, where the first
+ * one includes sorted input, the second contains original positions of the
+ * sorted result. If return_original_positions is false, then returns only the
+ * sorted input.
  */
-std::pair<torch::Tensor, torch::Tensor> Sort(
-    torch::Tensor input, int num_bits = 0);
+template <bool return_original_positions = true>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(torch::Tensor input, int num_bits = 0);

 /**
  * @brief Tests if each element of elements is in test_elements. Returns a
...
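The header now declares two Sort overloads: a pointer-based one for raw device buffers and a tensor-based one, with a compile-time flag selecting whether original positions are returned. Below is a minimal usage sketch of the tensor overload, assuming a CUDA build of libtorch and that the header above is included; the function name SortUsageExample and the tensor values are illustrative, not part of this change.

#include <torch/torch.h>
// Assumed: the GraphBolt header declaring graphbolt::ops::Sort is included.

void SortUsageExample() {
  auto ids = torch::tensor(
      {5, 1, 4, 2}, torch::dtype(torch::kLong).device(torch::kCUDA));
  // Default instantiation (<true>) returns the sorted keys plus the
  // original position of each sorted element.
  auto [sorted, positions] = graphbolt::ops::Sort(ids);
  // Keys-only instantiation skips allocating and sorting the index array.
  auto sorted_only = graphbolt::ops::Sort<false>(ids);
}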
@@ -15,7 +15,7 @@ namespace graphbolt {
 namespace ops {

 torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
-  auto sorted_test_elements = Sort(test_elements).first;
+  auto sorted_test_elements = Sort<false>(test_elements);
   auto allocator = cuda::GetAllocator();
   auto stream = cuda::GetCurrentStream();
   const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
...
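With the keys-only variant, IsIn no longer materializes the unused positions tensor; the sorted test set feeds directly into thrust::binary_search, which answers every membership query in one bulk pass. A CPU analogue of the same sort-then-binary-search idea, using only the standard library (the name IsInSketch is illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<bool> IsInSketch(
    const std::vector<int64_t>& elements, std::vector<int64_t> test_elements) {
  // Sort the test set once; each query then costs O(log n).
  std::sort(test_elements.begin(), test_elements.end());
  std::vector<bool> result(elements.size());
  for (size_t i = 0; i < elements.size(); ++i) {
    result[i] = std::binary_search(
        test_elements.begin(), test_elements.end(), elements[i]);
  }
  return result;
}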
@@ -15,36 +15,64 @@
 namespace graphbolt {
 namespace ops {

-std::pair<torch::Tensor, torch::Tensor> Sort(
-    torch::Tensor input, int num_bits) {
-  int64_t num_items = input.size(0);
-  // We utilize int64_t for the values array. (torch::kLong == int64_t)
-  auto original_idx =
-      torch::arange(num_items, input.options().dtype(torch::kLong));
-  auto sorted_array = torch::empty_like(input);
-  auto sorted_idx = torch::empty_like(original_idx);
+template <bool return_original_positions, typename scalar_t>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
+  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
   auto allocator = cuda::GetAllocator();
   auto stream = cuda::GetCurrentStream();
-  AT_DISPATCH_INDEX_TYPES(
-      input.scalar_type(), "SortImpl", ([&] {
-        const auto input_keys = input.data_ptr<index_t>();
-        const int64_t* input_values = original_idx.data_ptr<int64_t>();
-        index_t* sorted_keys = sorted_array.data_ptr<index_t>();
-        int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
-        if (num_bits == 0) {
-          num_bits = sizeof(index_t) * 8;
-        }
-        size_t tmp_storage_size = 0;
-        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-            nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
-            sorted_values, num_items, 0, num_bits, stream));
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-            tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
-            input_values, sorted_values, num_items, 0, num_bits, stream));
-      }));
-  return std::make_pair(sorted_array, sorted_idx);
+  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
+  auto sorted_array = torch::empty(num_items, options.dtype(dtype));
+  auto sorted_keys = sorted_array.data_ptr<scalar_t>();
+  if (num_bits == 0) {
+    num_bits = sizeof(scalar_t) * 8;
+  }
+
+  if constexpr (return_original_positions) {
+    // We utilize int64_t for the values array. (torch::kLong == int64_t)
+    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
+    auto sorted_idx = torch::empty_like(original_idx);
+    const int64_t* input_values = original_idx.data_ptr<int64_t>();
+    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
+    size_t tmp_storage_size = 0;
+    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+        nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
+        sorted_values, num_items, 0, num_bits, stream));
+    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
+    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
+        input_values, sorted_values, num_items, 0, num_bits, stream));
+    return std::make_pair(sorted_array, sorted_idx);
+  } else {
+    size_t tmp_storage_size = 0;
+    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
+        nullptr, tmp_storage_size, input_keys, sorted_keys, num_items, 0,
+        num_bits, stream));
+    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
+    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
+        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
+        num_items, 0, num_bits, stream));
+    return sorted_array;
+  }
+}
+
+template <bool return_original_positions>
+std::conditional_t<
+    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
+    torch::Tensor>
+Sort(torch::Tensor input, int num_bits) {
+  return AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "SortImpl", ([&] {
+           return Sort<return_original_positions>(
+               input.data_ptr<scalar_t>(), input.size(0), num_bits);
+         }));
 }
+
+template torch::Tensor Sort<false>(torch::Tensor input, int num_bits);
+template std::pair<torch::Tensor, torch::Tensor> Sort<true>(
+    torch::Tensor input, int num_bits);

 }  // namespace ops
 }  // namespace graphbolt
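Both branches of the new Sort follow CUB's two-phase calling convention: the first call with a null workspace pointer only computes the required temporary storage size, and the second call with the allocated buffer performs the actual sort over the low num_bits bits. Here is a standalone sketch of that pattern, assuming CUDA 11.2+ for cudaMallocAsync and omitting error checking; the wrapper name SortKeysSketch is illustrative.

#include <cuda_runtime.h>

#include <cub/cub.cuh>

template <typename scalar_t>
void SortKeysSketch(
    const scalar_t* keys_in, scalar_t* keys_out, int64_t num_items,
    int num_bits, cudaStream_t stream) {
  size_t tmp_storage_size = 0;
  // First call: nullptr workspace, so CUB only reports the size it needs.
  cub::DeviceRadixSort::SortKeys(
      nullptr, tmp_storage_size, keys_in, keys_out, num_items, 0, num_bits,
      stream);
  void* tmp_storage = nullptr;
  cudaMallocAsync(&tmp_storage, tmp_storage_size, stream);
  // Second call: same arguments plus the workspace; this one sorts.
  cub::DeviceRadixSort::SortKeys(
      tmp_storage, tmp_storage_size, keys_in, keys_out, num_items, 0, num_bits,
      stream);
  cudaFreeAsync(tmp_storage, stream);
}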
@@ -63,69 +63,58 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
   // Sort the unique_dst_ids tensor.
   auto sorted_unique_dst_ids =
-      allocator.AllocateStorage<scalar_t>(unique_dst_ids.size(0));
-  {
-    size_t workspace_size;
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        nullptr, workspace_size, unique_dst_ids_ptr,
-        sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
-        stream));
-    auto temp = allocator.AllocateStorage<char>(workspace_size);
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        temp.get(), workspace_size, unique_dst_ids_ptr,
-        sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
-        stream));
-  }
+      Sort<false>(unique_dst_ids_ptr, unique_dst_ids.size(0), num_bits);
+  auto sorted_unique_dst_ids_ptr =
+      sorted_unique_dst_ids.data_ptr<scalar_t>();

   // Mark dst nodes in the src_ids tensor.
   auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
   thrust::binary_search(
-      exec_policy, sorted_unique_dst_ids.get(),
-      sorted_unique_dst_ids.get() + unique_dst_ids.size(0), src_ids_ptr,
+      exec_policy, sorted_unique_dst_ids_ptr,
+      sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
       src_ids_ptr + src_ids.size(0), is_dst.get());

   // Filter the non-dst nodes in the src_ids tensor, hence only_src.
-  auto only_src = allocator.AllocateStorage<scalar_t>(src_ids.size(0));
-  auto only_src_size =
-      thrust::remove_copy_if(
-          exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-          is_dst.get(), only_src.get(), thrust::identity<bool>{}) -
-      only_src.get();
-  auto sorted_only_src =
-      allocator.AllocateStorage<scalar_t>(only_src_size);
-
-  {  // Sort the only_src tensor so that we can unique it with Encode
-     // operation later.
-    size_t workspace_size;
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        nullptr, workspace_size, only_src.get(), sorted_only_src.get(),
-        only_src_size, 0, num_bits, stream));
-    auto temp = allocator.AllocateStorage<char>(workspace_size);
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        temp.get(), workspace_size, only_src.get(), sorted_only_src.get(),
-        only_src_size, 0, num_bits, stream));
+  auto only_src =
+      torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
+  {
+    auto only_src_size =
+        thrust::remove_copy_if(
+            exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
+            is_dst.get(), only_src.data_ptr<scalar_t>(),
+            thrust::identity<bool>{}) -
+        only_src.data_ptr<scalar_t>();
+    only_src = only_src.slice(0, 0, only_src_size);
   }

-  auto unique_only_src = torch::empty(only_src_size, src_ids.options());
+  // Sort the only_src tensor so that we can unique it with Encode
+  // operation later.
+  auto sorted_only_src = Sort<false>(
+      only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);
+
+  auto unique_only_src =
+      torch::empty(only_src.size(0), src_ids.options());
   auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
   auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
   {  // Compute the unique operation on the only_src tensor.
     size_t workspace_size;
     CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-        nullptr, workspace_size, sorted_only_src.get(),
+        nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
         unique_only_src_ptr, cub::DiscardOutputIterator{},
-        unique_only_src_cnt.get(), only_src_size, stream));
+        unique_only_src_cnt.get(), only_src.size(0), stream));
     auto temp = allocator.AllocateStorage<char>(workspace_size);
     CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-        temp.get(), workspace_size, sorted_only_src.get(),
+        temp.get(), workspace_size, sorted_only_src.data_ptr<scalar_t>(),
         unique_only_src_ptr, cub::DiscardOutputIterator{},
-        unique_only_src_cnt.get(), only_src_size, stream));
+        unique_only_src_cnt.get(), only_src.size(0), stream));
+
+    auto unique_only_src_size =
+        cuda::CopyScalar(unique_only_src_cnt.get());
+    unique_only_src = unique_only_src.slice(
+        0, 0, static_cast<scalar_t>(unique_only_src_size));
   }
-  auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
-  unique_only_src = unique_only_src.slice(
-      0, 0, static_cast<scalar_t>(unique_only_src_size));

   auto real_order = torch::cat({unique_dst_ids, unique_only_src});
   // Sort here so that binary search can be used to lookup new_ids.
   auto [sorted_order, new_ids] = Sort(real_order, num_bits);
...
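The unique step works because the keys are sorted first: cub::DeviceRunLengthEncode::Encode collapses each run of equal IDs into a single key and reports how many unique keys it produced, which is then used to slice the output tensor. A CPU analogue with std::unique, assuming the input is already sorted as it is above (the name UniqueSortedSketch is illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> UniqueSortedSketch(std::vector<int64_t> sorted_ids) {
  // std::unique keeps the first element of each run of equal neighbors,
  // mirroring the unique keys emitted by DeviceRunLengthEncode::Encode.
  auto new_end = std::unique(sorted_ids.begin(), sorted_ids.end());
  sorted_ids.erase(new_end, sorted_ids.end());
  return sorted_ids;
}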