Unverified Commit 1b5b7de5 authored by czkkkkkk, committed by GitHub

[Graphbolt] Support UVA index select for TorchBasedFeature (#6417)

parent 75804a7d
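This change adds a `graphbolt::index_select` custom op and routes `TorchBasedFeature.read` through it, so a feature pinned in CPU memory can be gathered directly by the GPU over UVA. A minimal usage sketch (an illustration based on the docstring example added below; it assumes a CUDA build of DGL and `import dgl.graphbolt as gb`):

import torch
import dgl.graphbolt as gb

# Keep the feature in pinned (page-locked) host memory instead of copying it to the GPU.
feat = gb.TorchBasedFeature(torch.randn(1000, 16).pin_memory())

# Reading with GPU indices gathers over UVA; the result lands on the GPU.
ids = torch.randint(0, 1000, (64,), device="cuda")
out = feat.read(ids)
assert out.is_cuda and out.shape == (64, 16)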
@@ -4,6 +4,7 @@ set (CMAKE_CXX_STANDARD 17)
if(USE_CUDA)
  enable_language(CUDA)
  add_definitions(-DGRAPHBOLT_USE_CUDA)
endif()
# Find PyTorch cmake files and PyTorch versions with the python interpreter
@@ -56,6 +57,11 @@ target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${BOLT_DIR}
    "../third_party/pcg/include")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")
# TODO: upgrade to 17 for consistency with CXX standard once our linux CI supports it.
if(USE_CUDA)
  set_target_properties(${LIB_GRAPHBOLT_NAME} PROPERTIES CUDA_STANDARD 14)
endif()
# The Torch CMake configuration only sets up the path for the MKL library when
# using the conda distribution. The following is a workaround to address this
# when using a standalone installation of MKL.
......
/**
 * Copyright (c) 2023 by Contributors
 * @file cuda/index_select_impl.cu
 * @brief Index select operator implementation on CUDA.
 */
#include <c10/cuda/CUDAException.h>
#include <torch/script.h>
#include <numeric>
#include "../index_select.h"
namespace graphbolt {
namespace ops {
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
    const DType* const input, const int64_t input_len,
    const int64_t feature_size, const IdType* const index,
    const int64_t output_len, DType* const output) {
  // Each thread row (threadIdx.y) handles one output row; threads along x copy
  // that row's features. Blocks advance over output rows in a grid-stride loop.
  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t stride = blockDim.y * gridDim.x;
  while (out_row_index < output_len) {
    int64_t column = threadIdx.x;
    const int64_t in_row = index[out_row_index];
    assert(in_row >= 0 && in_row < input_len);
    while (column < feature_size) {
      output[out_row_index * feature_size + column] =
          input[in_row * feature_size + column];
      column += blockDim.x;
    }
    out_row_index += stride;
  }
}
template <typename DType, typename IdType>
torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
  const int64_t input_len = input.size(0);
  const int64_t return_len = index.size(0);
  // Treat all trailing dimensions as one flattened feature dimension.
  const int64_t feature_size = std::accumulate(
      input.sizes().begin() + 1, input.sizes().end(), 1, std::multiplies<>());
  torch::Tensor ret = torch::empty(
      {return_len, feature_size}, torch::TensorOptions()
                                      .dtype(input.dtype())
                                      .device(c10::DeviceType::CUDA));
  DType* input_ptr = input.data_ptr<DType>();
  IdType* index_ptr = index.data_ptr<IdType>();
  DType* ret_ptr = ret.data_ptr<DType>();
  cudaStream_t stream = 0;
  dim3 block(512, 1);
  // Find the smallest block.x that can fit feature_size, keeping 512 threads
  // per block; e.g. feature_size = 3 yields block = (4, 128).
  while (static_cast<int64_t>(block.x) >= 2 * feature_size) {
    block.x >>= 1;
    block.y <<= 1;
  }
  const dim3 grid((return_len + block.y - 1) / block.y);
  IndexSelectMultiKernel<<<grid, block, 0, stream>>>(
      input_ptr, input_len, feature_size, index_ptr, return_len, ret_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // Restore the input's trailing dimensions in the output shape.
  auto return_shape = std::vector<int64_t>({return_len});
  return_shape.insert(
      return_shape.end(), input.sizes().begin() + 1, input.sizes().end());
  ret = ret.reshape(return_shape);
  return ret;
}
/**
 * @brief UVA index select operator implementation on CUDA.
 *
 * The supported input types are: float, double, int, int64_t.
 * The supported index types are: int, int64_t.
 */
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index) {
  return AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Int, at::ScalarType::Long, input.scalar_type(),
      "UVAIndexSelectImpl", [&] {
        return AT_DISPATCH_INDEX_TYPES(
            index.scalar_type(), "UVAIndexSelectImpl", [&] {
              return UVAIndexSelectImpl_<scalar_t, index_t>(input, index);
            });
      });
}
} // namespace ops
} // namespace graphbolt
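As a rough illustration of the contract of the implementation above (a sketch, assuming a CUDA device and that the graphbolt ops have been loaded, e.g. via `import dgl.graphbolt`): the trailing dimensions of the input are flattened into one feature dimension for the kernel and restored on the output, and the gathered rows match a plain CPU gather.

import torch
import dgl.graphbolt  # noqa: F401  (loads the graphbolt custom ops)

feat = torch.randn(8, 2, 2).pin_memory()  # (N, ...) input, pinned
ids = torch.tensor([1, 5, 2]).cuda()
out = torch.ops.graphbolt.index_select(feat, ids)
assert out.is_cuda and out.shape == (3, 2, 2)  # trailing dims restored
assert torch.equal(out.cpu(), feat[ids.cpu()])  # same rows as a CPU gather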
/**
 * Copyright (c) 2023 by Contributors
 * @file index_select.cc
 * @brief Index select operators.
 */
#include "./index_select.h"
#include "./macro.h"
namespace graphbolt {
namespace ops {
torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
  if (input.is_pinned() &&
      (index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
        c10::DeviceType::CUDA, "UVAIndexSelect",
        { return UVAIndexSelectImpl(input, index); });
  }
  return input.index({index.to(torch::kLong)});
}
} // namespace ops
} // namespace graphbolt
/**
 * Copyright (c) 2023 by Contributors
 * @file index_select.h
 * @brief Index select operators.
 */
#ifndef GRAPHBOLT_INDEX_SELECT_H_
#define GRAPHBOLT_INDEX_SELECT_H_
#include <torch/script.h>
namespace graphbolt {
namespace ops {
/** @brief Implemented in the cuda directory. */
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);
/**
 * @brief Select rows from the input tensor according to the index tensor.
 *
 * NOTE:
 * 1. The input tensor can be multi-dimensional, but the index tensor must be
 *    1-D.
 * 2. If the input is in pinned memory and the index is in pinned or GPU
 *    memory, UVAIndexSelectImpl will be called. Otherwise, torch::index_select
 *    will be called.
 *
 * @param input Input tensor with shape (N, ...).
 * @param index Index tensor with shape (M,).
 * @return torch::Tensor Output tensor with shape (M, ...).
 */
torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index);
} // namespace ops
} // namespace graphbolt
#endif // GRAPHBOLT_INDEX_SELECT_H_
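A hedged sketch of the two code paths described in the note above (again assuming a CUDA device and that the graphbolt ops are loaded):

import torch
import dgl.graphbolt  # noqa: F401

pinned = torch.arange(12, dtype=torch.float32).reshape(4, 3).pin_memory()
plain = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Pinned input with a CUDA (or pinned) index -> UVAIndexSelectImpl, output on the GPU.
gpu_out = torch.ops.graphbolt.index_select(pinned, torch.tensor([0, 2]).cuda())
assert gpu_out.is_cuda

# Anything else -> plain Torch indexing on the CPU.
cpu_out = torch.ops.graphbolt.index_select(plain, torch.tensor([0, 2]))
assert not cpu_out.is_cuda and torch.equal(cpu_out, plain[[0, 2]])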
/**
 * Copyright (c) 2023 by Contributors
 * @file macro.h
 * @brief Graphbolt macros.
 */
#ifndef GRAPHBOLT_MACRO_H_
#define GRAPHBOLT_MACRO_H_
#include <torch/script.h>
namespace graphbolt {
// Dispatch an operator implementation to the CUDA device only; `name` is used
// in the error message when CUDA is unavailable or the device is not CUDA.
#ifdef GRAPHBOLT_USE_CUDA
#define GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(device_type, name, ...) \
  if (device_type == c10::DeviceType::CUDA) {                       \
    const auto XPU = c10::DeviceType::CUDA;                         \
    __VA_ARGS__                                                     \
  } else {                                                          \
    TORCH_CHECK(false, name, " is only available on CUDA device."); \
  }
#else
#define GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(device_type, name, ...) \
  TORCH_CHECK(false, name, " is only available on CUDA device.");
#endif
} // namespace graphbolt
#endif // GRAPHBOLT_MACRO_H_
@@ -9,6 +9,8 @@
#include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>
#include "./index_select.h"
namespace graphbolt {
namespace sampling {
@@ -63,6 +65,7 @@ TORCH_LIBRARY(graphbolt, m) {
  m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
  m.def("unique_and_compact", &UniqueAndCompact);
  m.def("isin", &IsIn);
  m.def("index_select", &ops::IndexSelect);
}
} // namespace sampling
......
@@ -57,6 +57,14 @@ class TorchBasedFeature(Feature):
            [3, 4]])
    >>> feature.read(torch.tensor([0]))
    tensor([[1, 2]])

    3. Pinned CPU feature.

    >>> torch_feat = torch.arange(10).reshape(2, -1).pin_memory()
    >>> feature = gb.TorchBasedFeature(torch_feat)
    >>> feature.read().device
    device(type='cuda', index=0)
    >>> feature.read(torch.tensor([0]).cuda()).device
    device(type='cuda', index=0)
    """

    def __init__(self, torch_feature: torch.Tensor):
@@ -75,8 +83,9 @@ class TorchBasedFeature(Feature):
    def read(self, ids: torch.Tensor = None):
        """Read the feature by index.

        The returned tensor is always in memory, no matter whether the feature
        store is in memory or on disk.
        If the feature is in pinned CPU memory and `ids` is on GPU or in pinned
        CPU memory, it will be read by the GPU and the returned tensor will be
        on GPU. Otherwise, the returned tensor will be on CPU.
        Parameters
        ----------
@@ -90,8 +99,10 @@ class TorchBasedFeature(Feature):
            The read feature.
        """
        if ids is None:
            if self._tensor.is_pinned():
                return self._tensor.cuda()
            return self._tensor
        return self._tensor[ids]
        return torch.ops.graphbolt.index_select(self._tensor, ids)
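As a side note (a sketch, not part of the original code; assumes a CUDA device and `import dgl.graphbolt as gb`): `ids` pinned in CPU memory also take the GPU path, because the C++ `IndexSelect` accepts either a pinned or a CUDA index.

import torch
import dgl.graphbolt as gb

feat = gb.TorchBasedFeature(torch.arange(6).reshape(3, 2).pin_memory())
pinned_ids = torch.tensor([0, 2]).pin_memory()
out = feat.read(pinned_ids)
assert out.is_cuda  # gathered by the GPU, result stays on the GPU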
    def size(self):
        """Get the size of the feature.
@@ -133,6 +144,11 @@ class TorchBasedFeature(Feature):
                f"The size of the feature is {self.size()}, "
                f"while the size of the value is {value.size()[1:]}."
            )
        if self._tensor.is_pinned() and value.is_cuda and ids.is_cuda:
            raise NotImplementedError(
                "Updating a feature in pinned CPU memory from the GPU is not "
                "supported yet."
            )
        self._tensor[ids] = value
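A short sketch of the guard above (assumes a CUDA device, `import dgl.graphbolt as gb`, and the `update(value, ids)` signature used by this class): GPU-side updates of a pinned feature are rejected for now, while the CPU path still works.

import torch
import dgl.graphbolt as gb

feat = gb.TorchBasedFeature(torch.zeros(4, 2).pin_memory())
try:
    feat.update(torch.ones(1, 2).cuda(), torch.tensor([0]).cuda())
except NotImplementedError:
    pass  # updating pinned CPU memory from the GPU is not supported yet
feat.update(torch.ones(1, 2), torch.tensor([0]))  # the CPU update still works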
......
import os
import tempfile
import unittest
import backend as F
import numpy as np
import pydantic
@@ -125,6 +128,54 @@ def test_torch_based_feature(in_memory):
    feature_a = feature_b = None


@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Tests for pinned memory are only meaningful on GPU.",
)
@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float64, torch.int32, torch.int64]
)
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
def test_torch_based_pinned_feature(dtype, idtype):
    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).pin_memory()
    b = torch.tensor(
        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype
    ).pin_memory()
    c = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).pin_memory()
    feature_a = gb.TorchBasedFeature(a)
    feature_b = gb.TorchBasedFeature(b)
    feature_c = gb.TorchBasedFeature(c)
    assert torch.equal(
        feature_a.read(),
        torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).cuda(),
    )
    assert feature_a.read().is_cuda
    assert torch.equal(
        feature_b.read(),
        torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype).cuda(),
    )
    assert feature_b.read().is_cuda
    assert torch.equal(
        feature_a.read(torch.tensor([0], dtype=idtype).cuda()),
        torch.tensor([[1, 2, 3]], dtype=dtype).cuda(),
    )
    assert feature_a.read(torch.tensor([0], dtype=idtype).cuda()).is_cuda
    assert torch.equal(
        feature_b.read(torch.tensor([1], dtype=idtype).cuda()),
        torch.tensor([[[4, 5], [6, 7]]], dtype=dtype).cuda(),
    )
    assert feature_b.read(torch.tensor([1], dtype=idtype).cuda()).is_cuda
    assert feature_c.read().is_cuda
    assert torch.equal(
        feature_c.read(torch.tensor([0], dtype=idtype)),
        torch.tensor([[1, 2, 3]], dtype=dtype),
    )
    assert not feature_c.read(torch.tensor([0], dtype=idtype)).is_cuda

def write_tensor_to_disk(dir, name, t, fmt="torch"):
    if fmt == "torch":
        torch.save(t, os.path.join(dir, name + ".pt"))
......