Unverified Commit 26e740a3 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Refactoring SliceCSCIndptr for Neighbor Sampling. (#6799)

parent ceef30b4
...@@ -24,6 +24,21 @@ torch::Tensor ExclusiveCumSum(torch::Tensor input); ...@@ -24,6 +24,21 @@ torch::Tensor ExclusiveCumSum(torch::Tensor input);
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes); torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
* given nodes and their indptr values.
*
* @param indptr The indptr tensor.
* @param nodes The nodes to read from indptr
*
* @return Tuple of tensors with values:
* (indptr[nodes + 1] - indptr[nodes], indptr[nodes]), the returned indegrees
* tensor (first one) has size nodes.size(0) + 1 so that calling ExclusiveCumSum
* on it gives the output indptr.
*/
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
torch::Tensor indptr, torch::Tensor nodes);
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl( std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes); torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
......
...@@ -102,26 +102,34 @@ struct PairSum { ...@@ -102,26 +102,34 @@ struct PairSum {
}; };
// Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes]) // Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
template <typename indptr_t> std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
auto SliceCSCIndptr( torch::Tensor indptr, torch::Tensor nodes) {
const indptr_t* const indptr, torch::Tensor nodes, cudaStream_t stream) {
auto allocator = cuda::GetAllocator(); auto allocator = cuda::GetAllocator();
const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream); const auto exec_policy =
thrust::cuda::par_nosync(allocator).on(cuda::GetCurrentStream());
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
// Read indptr only once in case it is pinned and access is slow. // Read indptr only once in case it is pinned and access is slow.
auto sliced_indptr = allocator.AllocateStorage<indptr_t>(num_nodes); auto sliced_indptr =
torch::empty(num_nodes, nodes.options().dtype(indptr.scalar_type()));
// compute in-degrees // compute in-degrees
auto in_degree = allocator.AllocateStorage<indptr_t>(num_nodes + 1); auto in_degree =
torch::empty(num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
thrust::counting_iterator<int64_t> iota(0); thrust::counting_iterator<int64_t> iota(0);
AT_DISPATCH_INDEX_TYPES(nodes.scalar_type(), "IndexSelectCSCNodes", ([&] { AT_DISPATCH_INTEGRAL_TYPES(
using nodes_t = index_t; indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
thrust::for_each( using indptr_t = scalar_t;
exec_policy, iota, iota + num_nodes, AT_DISPATCH_INDEX_TYPES(
SliceFunc<indptr_t, nodes_t>{ nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
nodes.data_ptr<nodes_t>(), indptr, using nodes_t = index_t;
in_degree.get(), sliced_indptr.get()}); thrust::for_each(
})); exec_policy, iota, iota + num_nodes,
return std::make_pair(std::move(in_degree), std::move(sliced_indptr)); SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>()});
}));
}));
return {in_degree, sliced_indptr};
} }
template <typename indptr_t, typename indices_t> template <typename indptr_t, typename indices_t>
...@@ -198,13 +206,14 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl( ...@@ -198,13 +206,14 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
auto stream = c10::cuda::getDefaultCUDAStream(); auto stream = c10::cuda::getDefaultCUDAStream();
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
return AT_DISPATCH_INTEGRAL_TYPES( return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] { indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t; using indptr_t = scalar_t;
auto [in_degree_ptr, sliced_indptr_ptr] = auto in_degree =
SliceCSCIndptr(indptr.data_ptr<indptr_t>(), nodes, stream); std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto in_degree = in_degree_ptr.get(); auto sliced_indptr =
auto sliced_indptr = sliced_indptr_ptr.get(); std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES( return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] { indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>( return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
...@@ -265,13 +274,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( ...@@ -265,13 +274,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) { torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
auto stream = c10::cuda::getDefaultCUDAStream(); auto stream = c10::cuda::getDefaultCUDAStream();
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
return AT_DISPATCH_INTEGRAL_TYPES( return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] { indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t; using indptr_t = scalar_t;
auto [in_degree_ptr, sliced_indptr_ptr] = auto in_degree =
SliceCSCIndptr(indptr.data_ptr<indptr_t>(), nodes, stream); std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto in_degree = in_degree_ptr.get(); auto sliced_indptr =
auto sliced_indptr = sliced_indptr_ptr.get(); std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
// Output indptr for the slice indexed by nodes. // Output indptr for the slice indexed by nodes.
torch::Tensor output_indptr = torch::empty( torch::Tensor output_indptr = torch::empty(
num_nodes + 1, nodes.options().dtype(indptr.scalar_type())); num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment