Unverified commit 26e740a3, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] Refactoring SliceCSCIndptr for Neighbor Sampling. (#6799)

parent ceef30b4
......@@ -24,6 +24,21 @@ torch::Tensor ExclusiveCumSum(torch::Tensor input);
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
* given nodes and their indptr values.
*
* @param indptr The indptr tensor.
* @param nodes The nodes to read from indptr.
*
* @return Tuple of tensors with values:
* (indptr[nodes + 1] - indptr[nodes], indptr[nodes]). The returned indegrees
* tensor (the first one) has size nodes.size(0) + 1, so that calling
* ExclusiveCumSum on it gives the output indptr.
*/
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
torch::Tensor indptr, torch::Tensor nodes);
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
......
......@@ -102,26 +102,34 @@ struct PairSum {
};
// Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
template <typename indptr_t>
auto SliceCSCIndptr(
const indptr_t* const indptr, torch::Tensor nodes, cudaStream_t stream) {
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
torch::Tensor indptr, torch::Tensor nodes) {
auto allocator = cuda::GetAllocator();
const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
const auto exec_policy =
thrust::cuda::par_nosync(allocator).on(cuda::GetCurrentStream());
const int64_t num_nodes = nodes.size(0);
// Read indptr only once in case it is pinned and access is slow.
auto sliced_indptr = allocator.AllocateStorage<indptr_t>(num_nodes);
auto sliced_indptr =
torch::empty(num_nodes, nodes.options().dtype(indptr.scalar_type()));
// compute in-degrees
auto in_degree = allocator.AllocateStorage<indptr_t>(num_nodes + 1);
auto in_degree =
torch::empty(num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
thrust::counting_iterator<int64_t> iota(0);
AT_DISPATCH_INDEX_TYPES(nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
using nodes_t = index_t;
thrust::for_each(
exec_policy, iota, iota + num_nodes,
SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr,
in_degree.get(), sliced_indptr.get()});
}));
return std::make_pair(std::move(in_degree), std::move(sliced_indptr));
AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
using nodes_t = index_t;
thrust::for_each(
exec_policy, iota, iota + num_nodes,
SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>()});
}));
}));
return {in_degree, sliced_indptr};
}
template <typename indptr_t, typename indices_t>
......@@ -198,13 +206,14 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
auto stream = c10::cuda::getDefaultCUDAStream();
const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t;
auto [in_degree_ptr, sliced_indptr_ptr] =
SliceCSCIndptr(indptr.data_ptr<indptr_t>(), nodes, stream);
auto in_degree = in_degree_ptr.get();
auto sliced_indptr = sliced_indptr_ptr.get();
auto in_degree =
std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto sliced_indptr =
std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
......@@ -265,13 +274,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
auto stream = c10::cuda::getDefaultCUDAStream();
const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t;
auto [in_degree_ptr, sliced_indptr_ptr] =
SliceCSCIndptr(indptr.data_ptr<indptr_t>(), nodes, stream);
auto in_degree = in_degree_ptr.get();
auto sliced_indptr = sliced_indptr_ptr.get();
auto in_degree =
std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto sliced_indptr =
std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
// Output indptr for the slice indexed by nodes.
torch::Tensor output_indptr = torch::empty(
num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment