"...text-generation-inference.git" did not exist on "72ab60fdd588266be85ff469eeb07b6c42e2f56a"
Unverified commit aad12df6, authored by Muhammed Fatih BALIN and committed by GitHub

[GraphBolt][CUDA] `gb.expand_indptr` (#6871)

parent 78fa316a
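This change renames the CUDA `CSRToCOO` operator to `ExpandIndptrImpl`, adds optional `node_ids` and `output_size` arguments, and exposes the operation to Python as `gb.expand_indptr`. A minimal usage sketch of the new operator, assuming the usual `import dgl.graphbolt as gb` and purely illustrative tensor values:

```python
import torch
import dgl.graphbolt as gb

# 3 columns with degrees 2, 0 and 1; indptr[-1] == 3 edges.
indptr = torch.tensor([0, 2, 2, 3])

# Default node_ids == arange(len(indptr) - 1): repeat each column id by its degree.
gb.expand_indptr(indptr, torch.int64)            # tensor([0, 0, 2])

# Explicit node_ids: repeat the given ids instead.
node_ids = torch.tensor([10, 20, 30])
gb.expand_indptr(indptr, torch.int64, node_ids)  # tensor([10, 10, 30])
```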
@@ -187,6 +187,7 @@ Utilities
     etype_tuple_to_str
     isin
     seed
+    expand_indptr
     add_reverse_edges
     exclude_seed_edges
     compact_csc_format
......
@@ -4,6 +4,8 @@
  * @file graphbolt/cuda_ops.h
  * @brief Available CUDA operations in Graphbolt.
  */
+#ifndef GRAPHBOLT_CUDA_OPS_H_
+#define GRAPHBOLT_CUDA_OPS_H_
 
 #include <torch/script.h>
@@ -162,16 +164,22 @@ torch::Tensor ExclusiveCumSum(torch::Tensor input);
 torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);
 
 /**
- * @brief CSRToCOO implements conversion from a given indptr offset tensor to a
- * COO format tensor including ids in [0, indptr.size(0) - 1).
+ * @brief ExpandIndptrImpl implements conversion from a given indptr offset
+ * tensor to a COO format tensor. If node_ids is not given, it is assumed to be
+ * equal to torch::arange(indptr.size(0) - 1, dtype=dtype).
  *
- * @param input A tensor containing IDs.
- * @param output_dtype Dtype of output.
+ * @param indptr The indptr offset tensor.
+ * @param dtype The dtype of the returned output tensor.
+ * @param node_ids Optional 1D tensor represents the node ids.
+ * @param output_size Optional value of indptr[-1]. Passing it eliminates CPU
+ * GPU synchronization.
  *
- * @return
- * - The resulting tensor with output_dtype.
+ * @return The resulting tensor.
  */
-torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype);
+torch::Tensor ExpandIndptrImpl(
+    torch::Tensor indptr, torch::ScalarType dtype,
+    torch::optional<torch::Tensor> node_ids = torch::nullopt,
+    torch::optional<int64_t> output_size = torch::nullopt);
 
 /**
  * @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
@@ -214,3 +222,5 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
 
 }  // namespace ops
 }  // namespace graphbolt
+
+#endif  // GRAPHBOLT_CUDA_OPS_H_
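As the doc comment notes, `output_size` should equal `indptr[-1]`; supplying it lets the CUDA path skip reading `indptr[-1]` back to the host, avoiding a CPU-GPU synchronization. A hedged sketch of that call pattern, assuming a CUDA build and that the edge count is already known on the host (for example from `indices.size(0)`):

```python
import torch
import dgl.graphbolt as gb

indptr = torch.tensor([0, 2, 2, 3], device="cuda")  # assumes a CUDA device
num_edges = 3  # already known on the host, e.g. indices.size(0)

# With output_size given, no device-to-host copy of indptr[-1] is needed.
rows = gb.expand_indptr(indptr, torch.int64, None, num_edges)
```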
@@ -4,6 +4,8 @@
  * @file graphbolt/cuda_sampling_ops.h
  * @brief Available CUDA sampling operations in Graphbolt.
  */
+#ifndef GRAPHBOLT_CUDA_SAMPLING_OPS_H_
+#define GRAPHBOLT_CUDA_SAMPLING_OPS_H_
 
 #include <graphbolt/fused_sampled_subgraph.h>
 #include <torch/script.h>
@@ -65,3 +67,5 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
 }  // namespace ops
 }  // namespace graphbolt
+
+#endif  // GRAPHBOLT_CUDA_SAMPLING_OPS_H_
 /**
  * Copyright (c) 2023 by Contributors
  * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
- * @file cuda/csr_to_coo.cu
- * @brief CSRToCOO operator implementation on CUDA.
+ * @file cuda/expand_indptr.cu
+ * @brief ExpandIndptr operator implementation on CUDA.
  */
 #include <thrust/iterator/constant_iterator.h>
 #include <thrust/iterator/counting_iterator.h>
@@ -16,10 +16,11 @@
 namespace graphbolt {
 namespace ops {
 
-template <typename indices_t>
+template <typename indices_t, typename nodes_t>
 struct RepeatIndex {
+  const nodes_t* nodes;
   __host__ __device__ auto operator()(indices_t i) {
-    return thrust::make_constant_iterator(i);
+    return thrust::make_constant_iterator(nodes ? nodes[i] : i);
   }
 };
 
@@ -38,42 +39,59 @@ struct AdjacentDifference {
   }
 };
 
-torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype) {
-  const auto num_rows = indptr.size(0) - 1;
-  thrust::counting_iterator<int64_t> iota(0);
-
-  return AT_DISPATCH_INTEGRAL_TYPES(
-      indptr.scalar_type(), "CSRToCOOIndptr", ([&] {
+torch::Tensor ExpandIndptrImpl(
+    torch::Tensor indptr, torch::ScalarType dtype,
+    torch::optional<torch::Tensor> nodes,
+    torch::optional<int64_t> output_size) {
+  if (!output_size.has_value()) {
+    output_size = AT_DISPATCH_INTEGRAL_TYPES(
+        indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
+          auto indptr_ptr = indptr.data_ptr<scalar_t>();
+          auto output_size = cuda::CopyScalar{indptr_ptr + indptr.size(0) - 1};
+          return static_cast<scalar_t>(output_size);
+        }));
+  }
+  auto csc_rows =
+      torch::empty(output_size.value(), indptr.options().dtype(dtype));
+
+  AT_DISPATCH_INTEGRAL_TYPES(
+      indptr.scalar_type(), "ExpandIndptrIndptr", ([&] {
        using indptr_t = scalar_t;
        auto indptr_ptr = indptr.data_ptr<indptr_t>();
-        auto num_edges =
-            cuda::CopyScalar{indptr.data_ptr<indptr_t>() + num_rows};
-        auto csr_rows = torch::empty(
-            static_cast<indptr_t>(num_edges),
-            indptr.options().dtype(output_dtype));
        AT_DISPATCH_INTEGRAL_TYPES(
-            output_dtype, "CSRToCOOIndices", ([&] {
+            dtype, "ExpandIndptrIndices", ([&] {
              using indices_t = scalar_t;
-              auto csc_rows_ptr = csr_rows.data_ptr<indices_t>();
+              auto csc_rows_ptr = csc_rows.data_ptr<indices_t>();
+              auto nodes_dtype = nodes ? nodes.value().scalar_type() : dtype;
+              AT_DISPATCH_INTEGRAL_TYPES(
+                  nodes_dtype, "ExpandIndptrNodes", ([&] {
+                    using nodes_t = scalar_t;
+                    auto nodes_ptr =
+                        nodes ? nodes.value().data_ptr<nodes_t>() : nullptr;
 
-              auto input_buffer = thrust::make_transform_iterator(
-                  iota, RepeatIndex<indices_t>{});
-              auto output_buffer = thrust::make_transform_iterator(
-                  iota, OutputBufferIndexer<indptr_t, indices_t>{
-                            indptr_ptr, csc_rows_ptr});
-              auto buffer_sizes = thrust::make_transform_iterator(
-                  iota, AdjacentDifference<indptr_t>{indptr_ptr});
+                    thrust::counting_iterator<int64_t> iota(0);
+                    auto input_buffer = thrust::make_transform_iterator(
+                        iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
+                    auto output_buffer = thrust::make_transform_iterator(
+                        iota, OutputBufferIndexer<indptr_t, indices_t>{
+                                  indptr_ptr, csc_rows_ptr});
+                    auto buffer_sizes = thrust::make_transform_iterator(
+                        iota, AdjacentDifference<indptr_t>{indptr_ptr});
 
-              constexpr int64_t max_copy_at_once =
-                  std::numeric_limits<int32_t>::max();
-              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
-                CUB_CALL(
-                    DeviceCopy::Batched, input_buffer + i, output_buffer + i,
-                    buffer_sizes + i, std::min(num_rows - i, max_copy_at_once));
-              }
+                    const auto num_rows = indptr.size(0) - 1;
+                    constexpr int64_t max_copy_at_once =
+                        std::numeric_limits<int32_t>::max();
+                    for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
+                      CUB_CALL(
+                          DeviceCopy::Batched, input_buffer + i,
+                          output_buffer + i, buffer_sizes + i,
+                          std::min(num_rows - i, max_copy_at_once));
+                    }
+                  }));
            }));
-        return csr_rows;
      }));
+  return csc_rows;
 }
 
 }  // namespace ops
......
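The kernel above expresses the expansion as one batched copy: for each row `i`, `AdjacentDifference` yields the segment length `indptr[i + 1] - indptr[i]`, `OutputBufferIndexer` targets `csc_rows + indptr[i]`, and `RepeatIndex` supplies a constant iterator over `nodes[i]` (or `i` when no node ids are given), all driven by `cub::DeviceCopy::Batched`. A rough Python emulation of that per-row decomposition, for intuition only (the real work happens on the GPU):

```python
import torch

def expand_indptr_reference(indptr, dtype, nodes=None):
    """Per-row emulation of the batched-copy scheme in expand_indptr.cu."""
    out = torch.empty(int(indptr[-1]), dtype=dtype)
    for i in range(indptr.numel() - 1):
        start, end = int(indptr[i]), int(indptr[i + 1])   # OutputBufferIndexer / AdjacentDifference
        value = int(nodes[i]) if nodes is not None else i  # RepeatIndex
        out[start:end] = value                             # one batched copy per row
    return out

indptr = torch.tensor([0, 2, 2, 7, 10, 12, 20])
expected = torch.arange(6).repeat_interleave(indptr.diff())
assert torch.equal(expand_indptr_reference(indptr, torch.int64), expected)
```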
@@ -26,7 +26,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
         in_degree, sliced_indptr, type_per_edge.value(), nodes,
         indptr.size(0) - 2, num_edges));
   }
-  auto rows = CSRToCOO(output_indptr, indices.scalar_type());
+  auto rows = ExpandIndptrImpl(
+      output_indptr, indices.scalar_type(), torch::nullopt, num_edges);
   auto i = torch::arange(output_indices.size(0), output_indptr.options());
   auto edge_ids =
       i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
......
@@ -157,6 +157,15 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
   auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
   auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
+  auto max_in_degree = torch::empty(
+      1,
+      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
+  AT_DISPATCH_INDEX_TYPES(
+      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
+        CUB_CALL(
+            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
+            max_in_degree.data_ptr<index_t>(), num_rows);
+      }));
   torch::optional<int64_t> num_edges_;
   torch::Tensor sub_indptr;
   torch::optional<torch::Tensor> sliced_probs_or_mask;
@@ -182,16 +191,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   if (!probs_or_mask.has_value() && fanouts.size() <= 1) {
     sub_indptr = ExclusiveCumSum(in_degree);
   }
-  auto max_in_degree = torch::empty(
-      1,
-      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
-  AT_DISPATCH_INDEX_TYPES(
-      indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
-        CUB_CALL(
-            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
-            max_in_degree.data_ptr<index_t>(), num_rows);
-      }));
-  auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
+  auto coo_rows = ExpandIndptrImpl(
+      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges_);
   const auto num_edges = coo_rows.size(0);
   const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
       static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
@@ -233,7 +234,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
       // Find the smallest integer type to store the edge id offsets.
-      // CSRToCOO had synch inside, so it is safe to read max_in_degree now.
+      // ExpandIndptr or IndexSelectCSCImpl had synch inside, so it is safe to
+      // read max_in_degree now.
       const int num_bits =
           cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
       std::array<int, 4> type_bits = {8, 16, 32, 64};
......
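The max-degree reduction is hoisted earlier because its pinned-memory result is only read after `ExpandIndptr` or `IndexSelectCSCImpl` have synchronized, and it drives the choice of the smallest integer type for the edge id offsets via `cuda::NumberOfBits` and the `{8, 16, 32, 64}` candidates. A sketch of that width selection with a hypothetical helper (the exact comparison is not shown in this hunk):

```python
def pick_edge_id_bits(max_in_degree: int) -> int:
    """Pick the narrowest candidate width that can represent max_in_degree."""
    num_bits = max(int(max_in_degree), 1).bit_length()
    for type_bits in (8, 16, 32, 64):
        if num_bits <= type_bits:
            return type_bits
    return 64

assert pick_edge_id_bits(200) == 8       # fits in 8 bits
assert pick_edge_id_bits(70_000) == 32   # needs 17 bits -> 32-bit offsets
```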
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file expand_indptr.cc
* @brief ExpandIndptr operators.
*/
#include <graphbolt/cuda_ops.h>
#include "./macro.h"
#include "./utils.h"
namespace graphbolt {
namespace ops {
torch::Tensor ExpandIndptr(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> node_ids,
    torch::optional<int64_t> output_size) {
  if (utils::is_on_gpu(indptr) &&
      (!node_ids.has_value() || utils::is_on_gpu(node_ids.value()))) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "ExpandIndptr", {
      return ExpandIndptrImpl(indptr, dtype, node_ids, output_size);
    });
  }
  if (!node_ids.has_value()) {
    node_ids = torch::arange(indptr.size(0) - 1, indptr.options().dtype(dtype));
  }
  return node_ids.value().to(dtype).repeat_interleave(
      indptr.diff(), 0, output_size);
}
} // namespace ops
} // namespace graphbolt
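`ExpandIndptr` dispatches to `ExpandIndptrImpl` only when `indptr` (and `node_ids`, if given) live on the GPU; otherwise it falls back to plain PyTorch ops. A pure-PyTorch restatement of that fallback, written as a sketch with keyword arguments because `output_size` is keyword-only in the Python `repeat_interleave` API:

```python
import torch

def expand_indptr_fallback(indptr, dtype, node_ids=None, output_size=None):
    """Same arange + repeat_interleave expression as the non-CUDA branch above."""
    if node_ids is None:
        node_ids = torch.arange(
            indptr.size(0) - 1, dtype=dtype, device=indptr.device
        )
    return node_ids.to(dtype).repeat_interleave(
        indptr.diff(), dim=0, output_size=output_size
    )
```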
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file expand_indptr.h
* @brief ExpandIndptr operators.
*/
#ifndef GRAPHBOLT_EXPAND_INDPTR_H_
#define GRAPHBOLT_EXPAND_INDPTR_H_
#include <torch/script.h>
namespace graphbolt {
namespace ops {
/**
* @brief ExpandIndptr implements conversion from a given indptr offset
* tensor to a COO format tensor. If node_ids is not given, it is assumed to be
* equal to torch::arange(indptr.size(0) - 1, dtype=dtype).
*
* @param indptr The indptr offset tensor.
* @param dtype The dtype of the returned output tensor.
* @param node_ids 1D tensor represents the node ids.
* @param output_size Optional, value of indptr[-1]. Passing it eliminates CPU
* GPU synchronization.
*
* @return The resulting tensor.
*/
torch::Tensor ExpandIndptr(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> node_ids = torch::nullopt,
    torch::optional<int64_t> output_size = torch::nullopt);
} // namespace ops
} // namespace graphbolt
#endif // GRAPHBOLT_EXPAND_INDPTR_H_
@@ -12,6 +12,7 @@
 #ifdef GRAPHBOLT_USE_CUDA
 #include "./cuda/max_uva_threads.h"
 #endif
+#include "./expand_indptr.h"
 #include "./index_select.h"
 #include "./random.h"
@@ -87,6 +88,7 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("isin", &IsIn);
   m.def("index_select", &ops::IndexSelect);
   m.def("index_select_csc", &ops::IndexSelectCSC);
+  m.def("expand_indptr", &ops::ExpandIndptr);
   m.def("set_seed", &RandomEngine::SetManualSeed);
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
......
@@ -15,6 +15,7 @@ __all__ = [
     "etype_tuple_to_str",
     "CopyTo",
     "isin",
+    "expand_indptr",
     "CSCFormatBase",
     "seed",
 ]
@@ -56,6 +57,51 @@ def isin(elements, test_elements):
     return torch.ops.graphbolt.isin(elements, test_elements)
 
 
+def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
+    """Converts a given indptr offset tensor to a COO format tensor. If
+    node_ids is not given, it is assumed to be equal to
+    torch.arange(indptr.size(0) - 1, dtype=dtype, device=indptr.device).
+
+    This is equivalent to
+
+    .. code:: python
+
+       if node_ids is None:
+           node_ids = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)
+       return node_ids.to(dtype).repeat_interleave(indptr.diff())
+
+    Parameters
+    ----------
+    indptr : torch.Tensor
+        A 1D tensor represents the csc_indptr tensor.
+    dtype : Optional[torch.dtype]
+        The dtype of the returned output tensor.
+    node_ids : Optional[torch.Tensor]
+        A 1D tensor represents the column node ids that the returned tensor
+        will be populated with.
+    output_size : Optional[int]
+        The size of the output tensor. Should be equal to indptr[-1]. Using
+        this argument avoids a stream synchronization to calculate the output
+        shape.
+
+    Returns
+    -------
+    torch.Tensor
+        The converted COO tensor with values from node_ids.
+    """
+    assert indptr.dim() == 1, "Indptr should be 1D tensor."
+    assert not (
+        node_ids is None and dtype is None
+    ), "One of node_ids or dtype must be given."
+    assert (
+        node_ids is None or node_ids.dim() == 1
+    ), "Node_ids should be 1D tensor."
+    if dtype is None:
+        dtype = node_ids.dtype
+    return torch.ops.graphbolt.expand_indptr(
+        indptr, dtype, node_ids, output_size
+    )
+
+
 def etype_tuple_to_str(c_etype):
     """Convert canonical etype from tuple to string.
......
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
-from ..base import CSCFormatBase, etype_str_to_tuple
+from ..base import CSCFormatBase, etype_str_to_tuple, expand_indptr
 
 
 def unique_and_compact(
@@ -240,9 +240,9 @@ def unique_and_compact_csc_formats(
 def _broadcast_timestamps(csc, dst_timestamps):
     """Broadcast the timestamp of each destination node to its corresponding
     source nodes."""
-    count = torch.diff(csc.indptr)
-    src_timestamps = torch.repeat_interleave(dst_timestamps, count)
-    return src_timestamps
+    return expand_indptr(
+        csc.indptr, node_ids=dst_timestamps, output_size=len(csc.indices)
+    )
 
 
 def compact_csc_format(
......
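With `node_ids`, `expand_indptr` subsumes the old timestamp broadcast: each destination's timestamp is repeated once per incident edge, and passing `output_size=len(csc.indices)` keeps the GPU path free of an extra synchronization. A small equivalence sketch with illustrative tensors:

```python
import torch
import dgl.graphbolt as gb

indptr = torch.tensor([0, 2, 3, 6])              # 3 destination nodes
dst_timestamps = torch.tensor([100, 200, 300])

old = torch.repeat_interleave(dst_timestamps, torch.diff(indptr))
new = gb.expand_indptr(indptr, node_ids=dst_timestamps, output_size=6)
assert torch.equal(old, new)  # tensor([100, 100, 200, 300, 300, 300])
```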
@@ -7,7 +7,13 @@ import torch
 
 from dgl.utils import recursive_apply
 
-from .base import apply_to, CSCFormatBase, etype_str_to_tuple, isin
+from .base import (
+    apply_to,
+    CSCFormatBase,
+    etype_str_to_tuple,
+    expand_indptr,
+    isin,
+)
 
 __all__ = ["SampledSubgraph"]
@@ -226,10 +232,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
     indices = torch.index_select(
         original_row_node_ids, dim=0, index=indices
     )
-    if original_column_node_ids is not None:
-        indptr = original_column_node_ids.repeat_interleave(indptr.diff())
-    else:
-        indptr = torch.arange(len(indptr) - 1).repeat_interleave(indptr.diff())
+    indptr = expand_indptr(
+        indptr, indices.dtype, original_column_node_ids, len(indices)
+    )
     return (indices, indptr)
......
@@ -248,6 +248,25 @@ def test_isin_non_1D_dim():
         gb.isin(elements, test_elements)
 
 
+def torch_expand_indptr(indptr, dtype, nodes=None):
+    if nodes is None:
+        nodes = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)
+    return nodes.to(dtype).repeat_interleave(indptr.diff())
+
+
+@pytest.mark.parametrize("nodes", [None, True])
+@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
+def test_expand_indptr(nodes, dtype):
+    if nodes:
+        nodes = torch.tensor([1, 7, 3, 4, 5, 8], dtype=dtype, device=F.ctx())
+    indptr = torch.tensor([0, 2, 2, 7, 10, 12, 20], device=F.ctx())
+    torch_result = torch_expand_indptr(indptr, dtype, nodes)
+    gb_result = gb.expand_indptr(indptr, dtype, nodes)
+    assert torch.equal(torch_result, gb_result)
+    gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item())
+    assert torch.equal(torch_result, gb_result)
+
+
 def test_csc_format_base_representation():
     csc_format_base = gb.CSCFormatBase(
         indptr=torch.tensor([0, 2, 4]),
......