Unverified Commit 22a2513d authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Reduce and hide `unique_and_compact` synchronizations. (#6841)

parent 08569139
@@ -104,15 +104,24 @@ inline bool is_zero<dim3>(dim3 size) {
  */
 template <typename scalar_t>
 struct CopyScalar {
-  CopyScalar(const scalar_t* device_ptr) : is_ready_(false) {
-    pinned_scalar_ = torch::empty(
-        sizeof(scalar_t),
-        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
+  CopyScalar() : is_ready_(true) { init_pinned_storage(); }
+
+  void record(at::cuda::CUDAStream stream = GetCurrentStream()) {
+    copy_event_.record(stream);
+    is_ready_ = false;
+  }
+
+  scalar_t* get() {
+    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr());
+  }
+
+  CopyScalar(const scalar_t* device_ptr) {
+    init_pinned_storage();
     auto stream = GetCurrentStream();
     CUDA_CALL(cudaMemcpyAsync(
         reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,
         sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
-    copy_event_.record(stream);
+    record(stream);
   }

   operator scalar_t() {
@@ -120,10 +129,16 @@ struct CopyScalar {
       copy_event_.synchronize();
       is_ready_ = true;
     }
-    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr())[0];
+    return *get();
   }

  private:
+  void init_pinned_storage() {
+    pinned_scalar_ = torch::empty(
+        sizeof(scalar_t),
+        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
+  }
+
   torch::Tensor pinned_scalar_;
   at::cuda::CUDAEvent copy_event_;
   bool is_ready_;
...
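Note on the `CopyScalar` rework above: the new default constructor, `get()`, and `record()` members let GPU work write the scalar directly into pinned host memory while only recording an event; the blocking `copy_event_.synchronize()` is deferred to the first host read through `operator scalar_t()`. A minimal usage sketch, not part of the diff; the reduction call is only indicated in a comment, and the `cuda::` helpers are assumed to be the ones from this header:

```cpp
// Hypothetical illustration of the deferred-synchronization pattern.
int64_t ExampleDeferredRead(const int64_t* d_values, int64_t size) {
  auto stream = cuda::GetCurrentStream();
  cuda::CopyScalar<int64_t> result;  // allocates pinned host storage only

  // Asynchronous GPU work writes into result.get() (device-accessible pinned
  // memory), e.g. a CUB reduction:
  //   cub::DeviceReduce::Max(tmp, tmp_bytes, d_values, result.get(), size, stream);

  result.record(stream);  // value is valid once the stream reaches this point

  // ... enqueue more GPU work here; no host synchronization has happened ...

  // First host read: operator scalar_t() waits on the recorded event.
  return static_cast<int64_t>(result);
}
```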
@@ -15,6 +15,7 @@
 #include <thrust/remove.h>

 #include <cub/cub.cuh>
+#include <type_traits>

 #include "./common.h"
 #include "./utils.h"
@@ -32,6 +33,24 @@ struct EqualityFunc {
   }
 };

+#define DefineReductionFunction(reduce_fn, name)                            \
+  template <typename scalar_iterator_t>                                     \
+  auto name(const scalar_iterator_t input, int64_t size) {                  \
+    auto allocator = cuda::GetAllocator();                                  \
+    auto stream = cuda::GetCurrentStream();                                 \
+    using scalar_t = std::remove_reference_t<decltype(input[0])>;           \
+    cuda::CopyScalar<scalar_t> result;                                      \
+    size_t workspace_size = 0;                                              \
+    reduce_fn(nullptr, workspace_size, input, result.get(), size, stream);  \
+    auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);     \
+    reduce_fn(                                                              \
+        tmp_storage.get(), workspace_size, input, result.get(), size, stream);\
+    return result;                                                          \
+  }
+
+DefineReductionFunction(cub::DeviceReduce::Max, Max);
+DefineReductionFunction(cub::DeviceReduce::Min, Min);
+
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor src_ids, const torch::Tensor dst_ids,
     const torch::Tensor unique_dst_ids, int num_bits) {
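For readability, `DefineReductionFunction(cub::DeviceReduce::Max, Max)` expands to roughly the function below: the standard two-phase CUB call (a null temporary buffer first to query `workspace_size`, then the real run), with the reduction result written straight into the `CopyScalar`'s pinned storage. Note that the macro does not record an event itself; callers either call `.record()` on the result (as `all_exist` does later) or rely on a later `stream.synchronize()` before reading the value.

```cpp
// Approximate expansion of DefineReductionFunction(cub::DeviceReduce::Max, Max).
template <typename scalar_iterator_t>
auto Max(const scalar_iterator_t input, int64_t size) {
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  using scalar_t = std::remove_reference_t<decltype(input[0])>;
  cuda::CopyScalar<scalar_t> result;
  size_t workspace_size = 0;
  // Phase 1: null temp storage, CUB only reports the workspace size needed.
  cub::DeviceReduce::Max(
      nullptr, workspace_size, input, result.get(), size, stream);
  auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
  // Phase 2: run the reduction; the scalar lands in result's pinned memory.
  cub::DeviceReduce::Max(
      tmp_storage.get(), workspace_size, input, result.get(), size, stream);
  return result;  // read on the host only after an event/stream sync
}
```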
@@ -48,17 +67,13 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
     auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();

-    // If the given num_bits argument is not in the reasonable range,
-    // we recompute it to speedup the expensive sort operations.
-    if (num_bits <= 0 || num_bits > sizeof(scalar_t) * 8) {
-      auto max_id = thrust::reduce(
-          exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-          static_cast<scalar_t>(0), thrust::maximum<scalar_t>{});
-      max_id = thrust::reduce(
-          exec_policy, unique_dst_ids_ptr,
-          unique_dst_ids_ptr + unique_dst_ids.size(0), max_id,
-          thrust::maximum<scalar_t>{});
-      num_bits = cuda::NumberOfBits(max_id + 1);
+    // If num_bits is not given, compute maximum vertex ids to compute
+    // num_bits later to speedup the expensive sort operations.
+    cuda::CopyScalar<scalar_t> max_id_src;
+    cuda::CopyScalar<scalar_t> max_id_dst;
+    if (num_bits == 0) {
+      max_id_src = Max(src_ids_ptr, src_ids.size(0));
+      max_id_dst = Max(unique_dst_ids_ptr, unique_dst_ids.size(0));
     }

     // Sort the unique_dst_ids tensor.
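`cuda::NumberOfBits` itself is not part of this diff; conceptually it returns how many low-order bits are needed to represent ids in `[0, max_id]`, so the radix sort below can skip the higher bits. A hedged sketch of such a helper (the real implementation may differ):

```cpp
// Sketch only; the real cuda::NumberOfBits lives elsewhere in GraphBolt.
// Returns the number of bits needed to represent any value in [0, range),
// which bounds the radix-sort passes in Sort<false>(..., num_bits).
template <typename T>
int NumberOfBitsSketch(T range) {
  int bits = 0;
  while (range > 0) {
    ++bits;
    range >>= 1;
  }
  return bits;
}
// Example: NumberOfBitsSketch<int64_t>(max_id + 1) == 20 when max_id == 999'999.
```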
@@ -78,41 +93,55 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     auto only_src =
         torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
     {
-      auto only_src_size =
-          thrust::remove_copy_if(
-              exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-              is_dst.get(), only_src.data_ptr<scalar_t>(),
-              thrust::identity<bool>{}) -
-          only_src.data_ptr<scalar_t>();
-      only_src = only_src.slice(0, 0, only_src_size);
+      auto is_src = thrust::make_transform_iterator(
+          is_dst.get(), thrust::logical_not<bool>{});
+      cuda::CopyScalar<int64_t> only_src_size;
+      size_t workspace_size = 0;
+      cub::DeviceSelect::Flagged(
+          nullptr, workspace_size, src_ids_ptr, is_src,
+          only_src.data_ptr<scalar_t>(), only_src_size.get(),
+          src_ids.size(0), stream);
+      auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
+      cub::DeviceSelect::Flagged(
+          tmp_storage.get(), workspace_size, src_ids_ptr, is_src,
+          only_src.data_ptr<scalar_t>(), only_src_size.get(),
+          src_ids.size(0), stream);
+      stream.synchronize();
+      only_src = only_src.slice(0, 0, static_cast<int64_t>(only_src_size));
     }

-    // Sort the only_src tensor so that we can unique it with Encode
-    // operation later.
+    // The code block above synchronizes, ensuring safe access to max_id_src
+    // and max_id_dst.
+    if (num_bits == 0) {
+      num_bits = cuda::NumberOfBits(
+          1 + std::max(
+                  static_cast<scalar_t>(max_id_src),
+                  static_cast<scalar_t>(max_id_dst)));
+    }
+
+    // Sort the only_src tensor so that we can unique it later.
     auto sorted_only_src = Sort<false>(
         only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);

     auto unique_only_src =
         torch::empty(only_src.size(0), src_ids.options());
     auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();

-    auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
     {  // Compute the unique operation on the only_src tensor.
-      size_t workspace_size;
-      CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-          nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
-          unique_only_src_ptr, cub::DiscardOutputIterator{},
-          unique_only_src_cnt.get(), only_src.size(0), stream));
-      auto temp = allocator.AllocateStorage<char>(workspace_size);
-      CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-          temp.get(), workspace_size, sorted_only_src.data_ptr<scalar_t>(),
-          unique_only_src_ptr, cub::DiscardOutputIterator{},
-          unique_only_src_cnt.get(), only_src.size(0), stream));
-      auto unique_only_src_size =
-          cuda::CopyScalar(unique_only_src_cnt.get());
+      cuda::CopyScalar<int64_t> unique_only_src_size;
+      size_t workspace_size = 0;
+      CUDA_CALL(cub::DeviceSelect::Unique(
+          nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
+          unique_only_src_ptr, unique_only_src_size.get(), only_src.size(0),
+          stream));
+      auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
+      CUDA_CALL(cub::DeviceSelect::Unique(
+          tmp_storage.get(), workspace_size,
+          sorted_only_src.data_ptr<scalar_t>(), unique_only_src_ptr,
+          unique_only_src_size.get(), only_src.size(0), stream));
+      stream.synchronize();
       unique_only_src = unique_only_src.slice(
-          0, 0, static_cast<scalar_t>(unique_only_src_size));
+          0, 0, static_cast<int64_t>(unique_only_src_size));
     }

     auto real_order = torch::cat({unique_dst_ids, unique_only_src});
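The deduplication above switches from `cub::DeviceRunLengthEncode::Encode` to `cub::DeviceSelect::Unique`: on already-sorted input both yield the same unique keys, but `Unique` skips the per-key run counts that were being discarded anyway, and its output count can land directly in the `CopyScalar`'s pinned storage. A standalone sketch of the pattern (names are illustrative; it uses `cudaMallocAsync` instead of the GraphBolt allocator):

```cpp
#include <cub/cub.cuh>

// d_sorted_in must already be sorted; d_unique_out needs room for num_items
// elements. num_selected must be device-accessible (e.g. pinned host memory,
// as CopyScalar provides) and may be read only after the stream is synced.
void UniqueSorted(
    const int64_t* d_sorted_in, int64_t* d_unique_out, int64_t* num_selected,
    int64_t num_items, cudaStream_t stream) {
  size_t workspace_size = 0;
  // Phase 1: query the temporary storage requirement.
  cub::DeviceSelect::Unique(
      nullptr, workspace_size, d_sorted_in, d_unique_out, num_selected,
      num_items, stream);
  void* d_temp = nullptr;
  cudaMallocAsync(&d_temp, workspace_size, stream);
  // Phase 2: compact away consecutive duplicates.
  cub::DeviceSelect::Unique(
      d_temp, workspace_size, d_sorted_in, d_unique_out, num_selected,
      num_items, stream);
  cudaFreeAsync(d_temp, stream);
  // cudaStreamSynchronize(stream) before dereferencing num_selected on host.
}
```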
@@ -123,39 +152,43 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     // Holds the found locations of the src and dst ids in the sorted_order.
     // Later is used to lookup the new ids of the src_ids and dst_ids
     // tensors.
-    auto new_src_ids_loc =
-        allocator.AllocateStorage<scalar_t>(src_ids.size(0));
     auto new_dst_ids_loc =
         allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
-    thrust::lower_bound(
-        exec_policy, sorted_order_ptr,
-        sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
-        src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
     thrust::lower_bound(
         exec_policy, sorted_order_ptr,
         sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
         dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
-    {  // Check if unique_dst_ids includes all dst_ids.
+
+    cuda::CopyScalar<bool> all_exist;
+    // Check if unique_dst_ids includes all dst_ids.
+    if (dst_ids.size(0) > 0) {
       thrust::counting_iterator<int64_t> iota(0);
       auto equal_it = thrust::make_transform_iterator(
           iota, EqualityFunc<scalar_t>{
                     sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
-      auto all_exist = thrust::all_of(
-          exec_policy, equal_it, equal_it + dst_ids.size(0),
-          thrust::identity<bool>());
-      if (!all_exist) {
-        throw std::out_of_range("Some ids not found.");
-      }
+      all_exist = Min(equal_it, dst_ids.size(0));
+      all_exist.record();
     }

+    auto new_src_ids_loc =
+        allocator.AllocateStorage<scalar_t>(src_ids.size(0));
+    thrust::lower_bound(
+        exec_policy, sorted_order_ptr,
+        sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
+        src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
+
     // Finally, lookup the new compact ids of the src and dst tensors via
     // gather operations.
     auto new_src_ids = torch::empty_like(src_ids);
-    auto new_dst_ids = torch::empty_like(dst_ids);
     thrust::gather(
         exec_policy, new_src_ids_loc.get(),
         new_src_ids_loc.get() + src_ids.size(0),
         new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
+
+    // Perform check before we gather for the dst indices.
+    if (dst_ids.size(0) > 0 && !static_cast<bool>(all_exist)) {
+      throw std::out_of_range("Some ids not found.");
+    }
+
+    auto new_dst_ids = torch::empty_like(dst_ids);
     thrust::gather(
         exec_policy, new_dst_ids_loc.get(),
         new_dst_ids_loc.get() + dst_ids.size(0),
...
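One more note on the dst-id validity check in the last hunk: `thrust::all_of` returns a host `bool` and therefore synchronizes immediately, while `Min` over the same boolean transform iterator produces a `CopyScalar<bool>` whose host read, and with it the synchronization, is postponed until just before the dst gather. The equivalence relies on the minimum of a sequence of `bool`s being `true` exactly when every element is `true`. A small thrust-based sketch of that equivalence (illustrative names, deliberately synchronous):

```cpp
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/logical.h>
#include <thrust/reduce.h>

// Both functions return the same answer; the diff uses the CUB-based Min()
// instead of thrust::reduce so the result stays in pinned memory and the
// host read can be deferred.
bool AllTrueViaMin(const thrust::device_vector<bool>& flags) {
  return thrust::reduce(
      flags.begin(), flags.end(), true, thrust::minimum<bool>{});
}

bool AllTrueViaAllOf(const thrust::device_vector<bool>& flags) {
  return thrust::all_of(
      flags.begin(), flags.end(), thrust::identity<bool>{});
}
```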