Unverified Commit 9632ab1d authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Specialize non-weighted neighbor sampling impl (#7215)

parent 7129905e
...@@ -56,6 +56,8 @@ class continuous_seed {
c[1] = std::sin(pi * r / 2);
}
uint64_t get_seed(int i) const { return s[i != 0]; }
#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t t) const {
const uint64_t kCurandSeed = 999961;  // Could be any random number.
......
...@@ -17,6 +17,9 @@
#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#if __CUDA_ARCH__ >= 700
#include <cuda/atomic>
#endif // __CUDA_ARCH__ >= 700
#include <limits>
#include <numeric>
#include <type_traits>
...@@ -30,6 +33,64 @@ namespace ops {
constexpr int BLOCK_SIZE = 128;
inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
// To match the type of "::atomicCAS", ignore lint warning.
using Type = unsigned long long int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
// To match the type of "::atomicCAS", ignore lint warning.
using Type = int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
/**
* @brief Performs neighbor sampling and fills the edge_ids array with
* original edge ids if sliced_indptr is valid. If not, then it fills the
* edge_ids array with numbers up to the node degree.
*/
template <typename indptr_t, typename indices_t>
__global__ void _ComputeRandomsNS(
const int64_t num_edges, const indptr_t* const sliced_indptr,
const indptr_t* const sub_indptr, const indptr_t* const output_indptr,
const indices_t* const csr_rows, const uint64_t random_seed,
indptr_t* edge_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
curandStatePhilox4_32_10_t rng;
curand_init(random_seed, i, 0, &rng);
while (i < num_edges) {
const auto row_position = csr_rows[i];
const auto row_offset = i - sub_indptr[row_position];
const auto output_offset = output_indptr[row_position];
const auto fanout = output_indptr[row_position + 1] - output_offset;
const auto rnd =
row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);
if (rnd < fanout) {
const indptr_t edge_id =
row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
#if __CUDA_ARCH__ >= 700
::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(
edge_ids[output_offset + rnd]);
a.fetch_max(edge_id, ::cuda::std::memory_order_relaxed);
#else
AtomicMax(edge_ids + output_offset + rnd, edge_id);
#endif // __CUDA_ARCH__
}
i += stride;
}
}
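As a reading aid only (not part of the patch): below is a minimal sequential C++ sketch of the per-row selection rule that _ComputeRandomsNS applies; the function name, parameters, and return type are illustrative assumptions. In sequential form the rule is classic reservoir sampling (Algorithm R): the first fanout candidates fill the reservoir, and candidate j afterwards replaces a uniformly chosen slot with probability fanout / (j + 1).

// Illustrative sequential reference; names are hypothetical, not from the patch.
#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

std::vector<int64_t> SampleRowUniform(
    int64_t row_start,   // analogous to sliced_indptr[row_position]
    int64_t row_degree,  // number of neighbors in the row
    int64_t fanout,      // output_indptr[row + 1] - output_indptr[row]
    std::mt19937_64& rng) {
  std::vector<int64_t> picked;
  picked.reserve(std::min(row_degree, fanout));
  for (int64_t j = 0; j < row_degree; ++j) {
    // The first fanout candidates each take their own slot; afterwards,
    // candidate j replaces a uniform slot with probability fanout / (j + 1).
    const int64_t rnd =
        j < fanout ? j : static_cast<int64_t>(rng() % (j + 1));
    if (rnd < fanout) {
      const int64_t edge_id = row_start + j;  // original edge id
      if (rnd < static_cast<int64_t>(picked.size())) {
        picked[rnd] = edge_id;
      } else {
        picked.push_back(edge_id);
      }
    }
  }
  // A uniform sample without replacement of size min(row_degree, fanout).
  return picked;
}

In the kernel, every edge gets its own thread, so the in-place replacement becomes an atomic write: cuda::atomic_ref<...>::fetch_max on __CUDA_ARCH__ >= 700 and the AtomicMax overloads above otherwise. Concurrent writers to the same slot are therefore resolved by keeping the larger edge id, and the caller zero-initializes the destination with torch::zeros so the max is taken against empty slots.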
/**
* @brief Fills the random_arr with random numbers and the edge_ids array with
* original edge ids. When random_arr is sorted along with edge_ids, the first
...@@ -251,119 +312,186 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Find the smallest integer type to store the edge id offsets. We synch
// the CUDAEvent so that the access is safe.
auto compute_num_bits = [&] {
max_in_degree_event.synchronize();
return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
};
if (layer || probs_or_mask.has_value()) {
const int num_bits = compute_num_bits();
std::array<int, 4> type_bits = {8, 16, 32, 64};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = std::make_unsigned_t<scalar_t>;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing edge_ids.");
// Using bfloat16 for random numbers works just as reliably as
// float32 and provides around 30% speedup.
using rnd_t = nv_bfloat16;
auto randoms =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto randoms_sorted =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
auto sorted_edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = index_t;
auto probs_or_mask_scalar_type = torch::kFloat32;
if (probs_or_mask.has_value()) {
probs_or_mask_scalar_type =
probs_or_mask.value().scalar_type();
}
GRAPHBOLT_DISPATCH_ALL_TYPES(
probs_or_mask_scalar_type, "SampleNeighborsProbs",
([&] {
using probs_t = scalar_t;
probs_t* sliced_probs_ptr = nullptr;
if (sliced_probs_or_mask.has_value()) {
sliced_probs_ptr = sliced_probs_or_mask.value()
.data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) /
BLOCK_SIZE);
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0,
num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(),
sliced_probs_ptr, indices_ptr, random_seed,
randoms.get(), edge_id_segments.get());
}));
}));

// Sort the random numbers along with edge ids, after
// sorting the first fanout elements of each row will
// give us the sampled edges.
CUB_CALL(
DeviceSegmentedSort::SortPairs, randoms.get(),
randoms_sorted.get(), edge_id_segments.get(),
sorted_edge_id_segments.get(), num_edges.value(), num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1);
picked_eids = torch::empty(
static_cast<indptr_t>(num_sampled_edges),
sub_indptr.options());
// Need to sort the sampled edges only when fanouts.size() == 1
// since multiple fanout sampling case is automatically going to
// be sorted.
if (type_per_edge && fanouts.size() == 1) {
// Ensuring sort result still ends up in
// sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments);
auto sampled_segment_end_it = thrust::make_transform_iterator(
iota,
SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree});
CUB_CALL(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(),
sampled_segment_end_it);
}
auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
// Copy the sampled edge ids into picked_eids tensor.
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
CUB_CALL(
DeviceCopy::Batched, input_buffer_it + i,
output_buffer_it + i, sampled_degree + i,
std::min(num_rows - i, max_copy_at_once));
}
}));
} else { // Non-weighted neighbor sampling.
picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());
const auto sort_needed = type_per_edge && fanouts.size() == 1;
const auto sliced_indptr_ptr =
sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +
BLOCK_SIZE - 1) /
BLOCK_SIZE);
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = index_t;
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandomsNS, grid, block, 0, num_edges.value(),
sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
picked_eids.data_ptr<indptr_t>());
}));
picked_eids =
picked_eids.slice(0, 0, static_cast<indptr_t>(num_sampled_edges));
// Need to sort the sampled edges only when fanouts.size() == 1
// since multiple fanout sampling case is automatically going to
// be sorted.
if (sort_needed) {
const int num_bits = compute_num_bits();
std::array<int, 4> type_bits = {8, 15, 31, 63};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = scalar_t;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing "
"edge_ids.");
auto picked_offsets = picked_eids.to(edge_id_dtype);
auto sorted_offsets = torch::empty_like(picked_offsets);
CUB_CALL(
DeviceSegmentedSort::SortKeys,
picked_offsets.data_ptr<edge_id_t>(),
sorted_offsets.data_ptr<edge_id_t>(), picked_eids.size(0),
num_rows, output_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>() + 1);
auto edge_id_offsets = ExpandIndptrImpl(
output_indptr, picked_eids.scalar_type(), sliced_indptr,
picked_eids.size(0));
picked_eids = sorted_offsets.to(picked_eids.scalar_type()) +
edge_id_offsets;
}));
}
}
output_indices = torch::empty(
picked_eids.size(0),
......
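A side note on the edge-id dtype selection in the hunk above, with a small self-contained sketch. Assumptions: NumberOfBits below is only a stand-in for graphbolt's cuda::NumberOfBits, and the returned strings stand for the torch scalar types. The weighted path searches the table {8, 16, 32, 64} because it casts the ids to an unsigned type, while the non-weighted path keeps signed ids; a signed 16/32/64-bit integer offers only 15/31/63 value bits, which is presumably why its table is {8, 15, 31, 63}.

// Illustrative only: picking the narrowest dtype that can hold the edge ids.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Stand-in for cuda::NumberOfBits: bits needed to represent max_value.
int NumberOfBits(int64_t max_value) {
  int bits = 0;
  while (max_value > 0) {
    ++bits;
    max_value >>= 1;
  }
  return bits;
}

const char* PickEdgeIdDtype(int64_t max_in_degree, bool signed_ids) {
  // Signed types lose one bit to the sign, hence {8, 15, 31, 63}.
  const std::array<int, 4> type_bits =
      signed_ids ? std::array<int, 4>{8, 15, 31, 63}
                 : std::array<int, 4>{8, 16, 32, 64};
  const std::array<const char*, 5> types = {
      "kByte", "kInt16", "kInt32", "kLong", "kLong"};
  const int num_bits = NumberOfBits(max_in_degree);
  const auto type_index =
      std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
      type_bits.begin();
  return types[type_index];
}

int main() {
  // A maximum in-degree of 300 needs 9 bits, so kInt16 is selected.
  std::printf("%s\n", PickEdgeIdDtype(300, /*signed_ids=*/true));
  return 0;
}

The two tables only lead to different dtypes when the ids need exactly 16 or 32 bits (e.g., 16 bits maps to kInt16 in the unsigned table but kInt32 in the signed one); for the 300-neighbor example both modes pick kInt16.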