Unverified commit 365bb723 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Optimize UVA `index_select_csc`. (#6995)

parent a47eacc1
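
In short: `_CopyIndicesAlignedKernel` used to locate the destination row of every output slot with a per-thread binary search (`cuda::UpperBound`) over `output_indptr_aligned`. This commit materializes those row positions once, as a `coo_aligned_rows` tensor produced by `ExpandIndptrImpl` and stored in the narrowest integer type that can index `num_nodes` rows, so each thread replaces an O(log num_nodes) search with a single load. A minimal host-side sketch of the before/after lookup (the helper names here are illustrative, not from the repo):

```cpp
#include <cassert>
#include <vector>

// Old path: find the row owning output slot `edge_idx` by binary search over
// the aligned indptr, i.e. cuda::UpperBound(indptr_aligned, n, edge_idx) - 1.
int RowOfEdgeBefore(const std::vector<int>& indptr_aligned, int edge_idx) {
  int lo = 0, hi = static_cast<int>(indptr_aligned.size()) - 1;
  while (lo < hi) {
    const int mid = lo + (hi - lo) / 2;
    if (indptr_aligned[mid] <= edge_idx) lo = mid + 1; else hi = mid;
  }
  return lo - 1;
}

// New path: the rows were expanded once up front (ExpandIndptrImpl in the
// commit), so the per-edge lookup is a single coalesced load.
int RowOfEdgeAfter(const std::vector<int>& coo_aligned_rows, int edge_idx) {
  return coo_aligned_rows[edge_idx];
}

int main() {
  const std::vector<int> indptr_aligned{0, 4, 6, 10};  // 3 rows, 10 slots
  const std::vector<int> coo_aligned_rows{0, 0, 0, 0, 1, 1, 2, 2, 2, 2};
  for (int e = 0; e < 10; ++e)
    assert(RowOfEdgeBefore(indptr_aligned, e) ==
           RowOfEdgeAfter(coo_aligned_rows, e));
}
```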
@@ -41,18 +41,18 @@ struct AlignmentFunc {
   }
 };
 
-template <typename indptr_t, typename indices_t>
+template <typename indptr_t, typename indices_t, typename coo_rows_t>
 __global__ void _CopyIndicesAlignedKernel(
-    const indptr_t edge_count, const int64_t num_nodes,
-    const indptr_t* const indptr, const indptr_t* const output_indptr,
+    const indptr_t edge_count, const indptr_t* const indptr,
+    const indptr_t* const output_indptr,
     const indptr_t* const output_indptr_aligned, const indices_t* const indices,
-    indices_t* const output_indices, const int64_t* const perm) {
+    const coo_rows_t* const coo_aligned_rows, indices_t* const output_indices,
+    const int64_t* const perm) {
   indptr_t idx = static_cast<indptr_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
   while (idx < edge_count) {
-    const auto permuted_row_pos =
-        cuda::UpperBound(output_indptr_aligned, num_nodes, idx) - 1;
+    const auto permuted_row_pos = coo_aligned_rows[idx];
     const auto row_pos = perm ? perm[permuted_row_pos] : permuted_row_pos;
     const auto out_row = output_indptr[row_pos];
     const auto d = output_indptr[row_pos + 1] - out_row;
@@ -97,7 +97,8 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
       torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));
   auto output_indptr_aligned =
-      allocator.AllocateStorage<indptr_t>(num_nodes + 1);
+      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));
+  auto output_indptr_aligned_ptr = output_indptr_aligned.data_ptr<indptr_t>();
   {
     // Returns the actual and modified_indegree as a pair, the
@@ -106,7 +107,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
     auto modified_in_degree = thrust::make_transform_iterator(
         iota, AlignmentFunc<indptr_t, indices_t>{in_degree, perm, num_nodes});
     auto output_indptr_pair = thrust::make_zip_iterator(
-        output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get());
+        output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr);
     thrust::tuple<indptr_t, indptr_t> zero_value{};
     // Compute the prefix sum over actual and modified indegrees.
     CUB_CALL(
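
Context for the hunk above: the actual in-degrees and the alignment-padded in-degrees are zipped together so one scan pass produces both `output_indptr` and `output_indptr_aligned`. A self-contained sketch of that fused-scan idea, written in plain Thrust rather than the repo's `CUB_CALL` wrapper (the values are illustrative):

```cpp
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/tuple.h>

#include <vector>

// Adds two tuples elementwise so one scan advances both sequences at once.
struct TupleSum {
  __host__ __device__ thrust::tuple<int, int> operator()(
      const thrust::tuple<int, int>& a,
      const thrust::tuple<int, int>& b) const {
    return thrust::make_tuple(
        thrust::get<0>(a) + thrust::get<0>(b),
        thrust::get<1>(a) + thrust::get<1>(b));
  }
};

int main() {
  // Actual in-degrees and their alignment-padded counterparts.
  const std::vector<int> h_deg{3, 1, 4, 1};
  const std::vector<int> h_deg_aligned{4, 2, 4, 2};
  thrust::device_vector<int> deg(h_deg.begin(), h_deg.end());
  thrust::device_vector<int> deg_aligned(
      h_deg_aligned.begin(), h_deg_aligned.end());
  thrust::device_vector<int> indptr(4), indptr_aligned(4);

  auto in = thrust::make_zip_iterator(
      thrust::make_tuple(deg.begin(), deg_aligned.begin()));
  auto out = thrust::make_zip_iterator(
      thrust::make_tuple(indptr.begin(), indptr_aligned.begin()));
  // One fused scan yields both indptr variants: {0,3,4,8} and {0,4,6,10}.
  thrust::exclusive_scan(
      in, in + 4, out, thrust::make_tuple(0, 0), TupleSum{});
  return 0;
}
```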
@@ -121,25 +122,42 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
     output_size = static_cast<indptr_t>(edge_count);
   }
   // Copy the modified number of edges.
-  auto edge_count_aligned =
-      cuda::CopyScalar{output_indptr_aligned.get() + num_nodes};
+  auto edge_count_aligned_ =
+      cuda::CopyScalar{output_indptr_aligned_ptr + num_nodes};
+  const int64_t edge_count_aligned = static_cast<indptr_t>(edge_count_aligned_);
   // Allocate output array with actual number of edges.
   torch::Tensor output_indices =
       torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
   const dim3 block(BLOCK_SIZE);
-  const dim3 grid(
-      (static_cast<indptr_t>(edge_count_aligned) + BLOCK_SIZE - 1) /
-      BLOCK_SIZE);
-  // Perform the actual copying, of the indices array into
-  // output_indices in an aligned manner.
-  CUDA_KERNEL_CALL(
-      _CopyIndicesAlignedKernel, grid, block, 0,
-      static_cast<indptr_t>(edge_count_aligned), num_nodes, sliced_indptr,
-      output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get(),
-      reinterpret_cast<indices_t*>(indices.data_ptr()),
-      reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
+  const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
+
+  // Find the smallest integer type to store the coo_aligned_rows tensor.
+  const int num_bits = cuda::NumberOfBits(num_nodes);
+  std::array<int, 4> type_bits = {8, 15, 31, 63};
+  const auto type_index =
+      std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
+      type_bits.begin();
+  std::array<torch::ScalarType, 5> types = {
+      torch::kByte, torch::kInt16, torch::kInt32, torch::kLong, torch::kLong};
+  auto coo_dtype = types[type_index];
+  auto coo_aligned_rows = ExpandIndptrImpl(
+      output_indptr_aligned, coo_dtype, torch::nullopt, edge_count_aligned);
+
+  AT_DISPATCH_INTEGRAL_TYPES(
+      coo_dtype, "UVAIndexSelectCSCCopyIndicesCOO", ([&] {
+        using coo_rows_t = scalar_t;
+        // Perform the actual copying, of the indices array into
+        // output_indices in an aligned manner.
+        CUDA_KERNEL_CALL(
+            _CopyIndicesAlignedKernel, grid, block, 0,
+            static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
+            output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
+            reinterpret_cast<indices_t*>(indices.data_ptr()),
+            coo_aligned_rows.data_ptr<coo_rows_t>(),
+            reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
+      }));
   return {output_indptr, output_indices};
 }
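
A note on the type-narrowing block introduced above: `coo_aligned_rows` has one entry per aligned edge, so storing it as 8- or 16-bit integers when `num_nodes` permits meaningfully reduces the extra memory traffic of the precomputed lookup. The thresholds are 8, 15, 31, and 63 rather than 8, 16, 32, and 64 because `torch::kInt16`, `torch::kInt32`, and `torch::kLong` are signed and give up one usable bit, while `torch::kByte` is unsigned. Below is a standalone re-implementation of the selection logic for illustration; `NumberOfBits` here is an assumed stand-in for `cuda::NumberOfBits`:

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>

// Assumed equivalent of cuda::NumberOfBits: bits needed for values in [0, n).
int NumberOfBits(int64_t n) {
  int bits = 0;
  while ((int64_t{1} << bits) < n) ++bits;
  return bits;
}

const char* CooRowDtype(int64_t num_nodes) {
  const int num_bits = NumberOfBits(num_nodes);
  // Signed types hold one fewer usable bit than their width, hence 15/31/63;
  // kByte is unsigned, so all 8 bits count.
  const std::array<int, 4> type_bits = {8, 15, 31, 63};
  const auto type_index =
      std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
      type_bits.begin();
  // The fifth entry guards the type_bits.end() case, keeping indexing safe.
  const std::array<const char*, 5> types = {
      "kByte", "kInt16", "kInt32", "kLong", "kLong"};
  return types[type_index];
}

int main() {
  std::cout << CooRowDtype(200) << '\n';         // kByte: ids fit in 8 bits
  std::cout << CooRowDtype(40000) << '\n';       // kInt32: needs 16 bits > 15
  std::cout << CooRowDtype(3000000000) << '\n';  // kLong: needs 32 bits > 31
}
```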