"docker/install/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e9440acb06e1ab7a9d2f6f888a092b115da3155c"
Unverified commit 096bbf96, authored by Muhammed Fatih BALIN, committed by GitHub

[Graphbolt][CUDA] Optimize UVA index select (#6632)

parent a28f1f9a
@@ -3,6 +3,7 @@
  * @file cuda/index_select_impl.cu
  * @brief Index select operator implementation on CUDA.
  */
+#include <c10/core/ScalarType.h>
 #include <c10/cuda/CUDAStream.h>
 #include <torch/script.h>
@@ -101,14 +102,16 @@ template <typename DType, typename IdType>
 torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   const int64_t input_len = input.size(0);
   const int64_t return_len = index.size(0);
-  const int64_t feature_size = std::accumulate(
-      input.sizes().begin() + 1, input.sizes().end(), 1, std::multiplies<>());
+  const int64_t original_feature_size = std::accumulate(
+      input.sizes().begin() + 1, input.sizes().end(), 1ll, std::multiplies<>());
+  const auto aligned_feature_size =
+      input.element_size() * original_feature_size / sizeof(DType);
   torch::Tensor ret = torch::empty(
-      {return_len, feature_size}, torch::TensorOptions()
-                                      .dtype(input.dtype())
-                                      .device(c10::DeviceType::CUDA));
-  DType* input_ptr = input.data_ptr<DType>();
-  DType* ret_ptr = ret.data_ptr<DType>();
+      {return_len, original_feature_size}, torch::TensorOptions()
+                                               .dtype(input.dtype())
+                                               .device(c10::DeviceType::CUDA));
+  DType* input_ptr = reinterpret_cast<DType*>(input.data_ptr());
+  DType* ret_ptr = reinterpret_cast<DType*>(ret.data_ptr());
   // Sort the index to improve the memory access pattern.
   torch::Tensor sorted_index, permutation;
@@ -118,7 +121,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   cudaStream_t stream = torch::cuda::getDefaultCUDAStream();
-  if (feature_size == 1) {
+  if (aligned_feature_size == 1) {
     // Use a single thread to process each output row to avoid wasting threads.
     const int num_threads = cuda::FindNumThreads(return_len);
     const int num_blocks = (return_len + num_threads - 1) / num_threads;
@@ -127,23 +130,24 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
         input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
   } else {
     dim3 block(512, 1);
-    while (static_cast<int64_t>(block.x) >= 2 * feature_size) {
+    while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
       block.x >>= 1;
       block.y <<= 1;
     }
     const dim3 grid((return_len + block.y - 1) / block.y);
-    if (feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
+    if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
       // When feature size is smaller than GPU cache line size, use unaligned
       // version for less SM usage, which is more resource efficient.
       CUDA_KERNEL_CALL(
           IndexSelectMultiKernel, grid, block, 0, stream, input_ptr, input_len,
-          feature_size, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
+          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,
+          permutation_ptr);
     } else {
       // Use aligned version to improve the memory access pattern.
       CUDA_KERNEL_CALL(
           IndexSelectMultiKernelAligned, grid, block, 0, stream, input_ptr,
-          input_len, feature_size, index_sorted_ptr, return_len, ret_ptr,
-          permutation_ptr);
+          input_len, aligned_feature_size, index_sorted_ptr, return_len,
+          ret_ptr, permutation_ptr);
     }
   }
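To make the launch-shape heuristic above concrete, here is a small standalone host-side sketch (not part of the patch; the feature size and row count are made up): starting from a 512x1 block, block.x is halved and block.y doubled until block.x drops below 2 * aligned_feature_size, so each row keeps enough threads for its copies while the remaining threads go toward packing more output rows into one block.

// Standalone sketch of the block-shape computation above; values are hypothetical.
#include <cstdint>
#include <iostream>

int main() {
  const std::int64_t aligned_feature_size = 3;  // e.g. 24-byte rows copied as uint64_t
  const std::int64_t return_len = 1000;         // number of output rows
  std::int64_t block_x = 512, block_y = 1;      // mirrors dim3 block(512, 1)
  while (block_x >= 2 * aligned_feature_size) {
    block_x >>= 1;  // fewer threads per row...
    block_y <<= 1;  // ...more rows per block
  }
  const std::int64_t grid_x = (return_len + block_y - 1) / block_y;
  // Prints: block = (4, 128), grid = 8
  std::cout << "block = (" << block_x << ", " << block_y << "), grid = "
            << grid_x << std::endl;
  return 0;
}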
@@ -157,18 +161,40 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
 /**
  * @brief UVA index select operator implementation on CUDA.
  *
- * The supporting input types are: float, double, int, int64_t.
+ * All basic torch types are supported for input.
  * The supporting index types are: int, int64_t.
  */
 torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index) {
-  return AT_DISPATCH_FLOATING_TYPES_AND2(
-      at::ScalarType::Int, at::ScalarType::Long, input.scalar_type(),
-      "UVAIndexSelectImpl", [&] {
-        return AT_DISPATCH_INDEX_TYPES(
-            index.scalar_type(), "UVAIndexSelectImpl", [&] {
-              return UVAIndexSelectImpl_<scalar_t, index_t>(input, index);
-            });
-      });
+  return AT_DISPATCH_INDEX_TYPES(
+      index.scalar_type(), "UVAIndexSelectImpl", ([&] {
+        const auto ptr = (size_t)input.data_ptr();
+        const int64_t feature_size = std::accumulate(
+            input.sizes().begin() + 1, input.sizes().end(), 1ll,
+            std::multiplies<>());
+        // We perform the copy with datatype of size powers of 2, and the
+        // maximum data type we use has 16 bytes. We check the alignment of the
+        // pointer and the feature dimensionality to determine the largest
+        // type to use for the copy to minimize the number of CUDA threads used.
+        // Alignment denotes the maximum suitable alignment and datatype size
+        // for the copies.
+        const int aligned_access_size =
+            std::gcd(16, std::gcd(ptr, input.element_size() * feature_size));
+        switch (aligned_access_size) {
+          case 1:
+            return UVAIndexSelectImpl_<uint8_t, index_t>(input, index);
+          case 2:
+            return UVAIndexSelectImpl_<uint16_t, index_t>(input, index);
+          case 4:
+            return UVAIndexSelectImpl_<uint32_t, index_t>(input, index);
+          case 8:
+            return UVAIndexSelectImpl_<uint64_t, index_t>(input, index);
+          case 16:
+            return UVAIndexSelectImpl_<float4, index_t>(input, index);
+          default:
+            TORCH_CHECK(false, "UVAIndexSelectImpl: Unreachable code path!");
+            return torch::Tensor{};
+        }
+      }));
 }
 }  // namespace ops
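As a worked example of the dispatch above (a standalone sketch; the pointer value and feature shape are hypothetical): a float32 tensor whose rows hold 6 elements spans 24 bytes per row, and with a 16-byte-aligned pinned base pointer, std::gcd(16, std::gcd(ptr, 24)) evaluates to 8, so rows are copied as uint64_t and UVAIndexSelectImpl_ sees aligned_feature_size = 3 instead of 6.

// Standalone illustration of the aligned_access_size computation; all values
// below are hypothetical.
#include <cstdint>
#include <iostream>
#include <numeric>  // std::gcd (C++17)

int main() {
  const std::int64_t ptr = 0x7f0000000010LL;  // 16-byte-aligned base address
  const std::int64_t element_size = 4;        // float32
  const std::int64_t feature_size = 6;        // elements per row -> 24 bytes/row
  const std::int64_t aligned_access_size =
      std::gcd<std::int64_t>(16, std::gcd(ptr, element_size * feature_size));
  // Prints 8: copies are done as uint64_t, so a row needs 24 / 8 = 3 threads
  // instead of 6 float-sized copies.
  std::cout << aligned_access_size << std::endl;
  return 0;
}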
@@ -141,12 +141,27 @@ def test_torch_based_feature(in_memory):
     reason="Tests for pinned memory are only meaningful on GPU.",
 )
 @pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float64, torch.int32, torch.int64]
+    "dtype",
+    [
+        torch.float32,
+        torch.float64,
+        torch.int32,
+        torch.int64,
+        torch.int8,
+        torch.float16,
+        torch.complex128,
+    ],
 )
 @pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
-@pytest.mark.parametrize("shape", [(2, 1), (2, 3), (2, 2, 2)])
+@pytest.mark.parametrize("shape", [(2, 1), (2, 3), (2, 2, 2), (137, 13, 3)])
 def test_torch_based_pinned_feature(dtype, idtype, shape):
-    tensor = torch.arange(0, reduce(mul, shape), dtype=dtype).reshape(shape)
+    if dtype == torch.complex128:
+        tensor = torch.complex(
+            torch.randint(0, 13, shape, dtype=torch.float64),
+            torch.randint(0, 13, shape, dtype=torch.float64),
+        )
+    else:
+        tensor = torch.randint(0, 13, shape, dtype=dtype)
     test_tensor = tensor.clone().detach()
     test_tensor_cuda = test_tensor.cuda()
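For context, a rough usage sketch of the path these tests exercise; the gb.TorchBasedFeature / read() call pattern is an assumption about the surrounding GraphBolt API, not code taken from the patch:

# Assumed usage: a pinned host tensor wrapped in a GraphBolt feature store,
# read with a CUDA index so the UVA index-select kernel does the gathering.
import torch
import dgl.graphbolt as gb

shape, dtype = (137, 13, 3), torch.float16
tensor = torch.randint(0, 13, shape, dtype=dtype)
feature = gb.TorchBasedFeature(tensor.pin_memory())  # pinned -> UVA reads
ids = torch.tensor([0, 5, 136], device="cuda")
out = feature.read(ids)  # rows gathered on the GPU from pinned host memory
assert torch.equal(out.cpu(), tensor[ids.cpu()])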