Unverified commit 0f3bfd7e authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Refactor `IndexSelectCSC` and add `output_size` argument (#6927)

parent 3795a006
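A minimal caller-side sketch of the pattern this change enables (hedged: it assumes the graphbolt::ops declarations shown in this diff are in scope, and the helper name SliceIndicesAndTypes is illustrative, not part of the commit). When several per-edge tensors are sliced with the same nodes, the edge count obtained from the first IndexSelectCSC call can be forwarded as output_size, so later calls allocate their outputs without copying the edge count back from the GPU:

#include <torch/torch.h>
#include <tuple>
// Assumes the header declaring graphbolt::ops::IndexSelectCSC (shown later in
// this diff) is also included.

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceIndicesAndTypes(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor type_per_edge,
    torch::Tensor nodes) {
  // First slice: output_size is unknown, so the implementation computes it
  // (on CUDA this involves a device-to-host copy of the edge count).
  auto [out_indptr, out_indices] =
      graphbolt::ops::IndexSelectCSC(indptr, indices, nodes);
  // Same nodes, hence the same number of copied edges: forward it as
  // output_size so the second slice allocates without synchronizing.
  auto out_types = std::get<1>(graphbolt::ops::IndexSelectCSC(
      indptr, type_per_edge, nodes, out_indices.size(0)));
  return std::make_tuple(out_indptr, out_indices, out_types);
}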
@@ -68,6 +68,27 @@ Sort(torch::Tensor input, int num_bits = 0);
*/
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE: The shape of all tensors must be 1-D.
*
* @param in_degree Indegree tensor containing degrees of nodes being copied.
* @param sliced_indptr Sliced_indptr tensor containing indptr values of nodes
* being copied.
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param nodes_max An upper bound on `nodes.max()`.
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
torch::Tensor nodes, int64_t nodes_max,
torch::optional<int64_t> output_size = torch::nullopt);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
@@ -77,11 +98,13 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<int64_t> output_size = torch::nullopt);
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
......
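The overload added above takes the outputs of SliceCSCIndptr directly, so callers that have already sliced the indptr do not pay for a second slice. A short sketch of how the pieces compose, mirroring the wrapper added later in this diff (assumes these declarations are in scope; SelectColumns is an illustrative name):

#include <torch/torch.h>
#include <tuple>

std::tuple<torch::Tensor, torch::Tensor> SelectColumns(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
  // in_degree[i] == indptr[nodes[i] + 1] - indptr[nodes[i]] and
  // sliced_indptr[i] == indptr[nodes[i]].
  auto [in_degree, sliced_indptr] =
      graphbolt::ops::SliceCSCIndptr(indptr, nodes);
  // nodes_max is any upper bound on nodes.max(); valid node IDs lie in
  // [0, indptr.size(0) - 2], so that bound is used here, as in the wrapper.
  return graphbolt::ops::IndexSelectCSCImpl(
      in_degree, sliced_indptr, indices, nodes,
      /*nodes_max=*/indptr.size(0) - 2, /*output_size=*/torch::nullopt);
}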
@@ -86,14 +86,15 @@ template <typename indptr_t, typename indices_t>
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
torch::Tensor indices, const int64_t num_nodes,
const indptr_t* const in_degree, const indptr_t* const sliced_indptr,
const int64_t* const perm, torch::TensorOptions nodes_options,
torch::ScalarType indptr_scalar_type) {
const int64_t* const perm, torch::TensorOptions options,
torch::ScalarType indptr_scalar_type,
torch::optional<int64_t> output_size) {
auto allocator = cuda::GetAllocator();
thrust::counting_iterator<int64_t> iota(0);
// Output indptr for the slice indexed by nodes.
auto output_indptr =
torch::empty(num_nodes + 1, nodes_options.dtype(indptr_scalar_type));
torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));
auto output_indptr_aligned =
allocator.AllocateStorage<indptr_t>(num_nodes + 1);
@@ -114,16 +115,18 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
}
// Copy the actual total number of edges.
auto edge_count =
cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
if (!output_size.has_value()) {
auto edge_count =
cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
output_size = static_cast<indptr_t>(edge_count);
}
// Copy the modified number of edges.
auto edge_count_aligned =
cuda::CopyScalar{output_indptr_aligned.get() + num_nodes};
// Allocate output array with actual number of edges.
torch::Tensor output_indices = torch::empty(
static_cast<indptr_t>(edge_count),
nodes_options.dtype(indices.scalar_type()));
torch::Tensor output_indices =
torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(static_cast<indptr_t>(edge_count_aligned) + BLOCK_SIZE - 1) /
@@ -141,26 +144,22 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
}
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
torch::Tensor nodes, int num_bits, torch::optional<int64_t> output_size) {
// Sorting nodes so that accesses over PCI-e are more regular.
const auto sorted_idx =
Sort(nodes, cuda::NumberOfBits(indptr.size(0) - 1)).second;
const auto sorted_idx = Sort(nodes, num_bits).second;
const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
sliced_indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t;
auto in_degree =
std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto sliced_indptr =
std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
indices, num_nodes, in_degree, sliced_indptr,
indices, num_nodes, in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
sorted_idx.data_ptr<int64_t>(), nodes.options(),
indptr.scalar_type());
sliced_indptr.scalar_type(), output_size);
}));
}));
}
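In this pinned-memory (UVA) path, nodes are radix sorted first so that accesses over PCIe are more regular, and the new num_bits argument, derived from nodes_max by the caller, bounds how many key bits the sort has to process. A standalone illustration of what that bit count means (this is not the library's cuda::NumberOfBits implementation, only a sketch of its intended semantics):

#include <cstdint>

// Smallest number of radix bits that can represent every value in [0, range).
int NumberOfBitsSketch(int64_t range) {
  int bits = 0;
  while (bits < 63 && (int64_t{1} << bits) < range) ++bits;
  return bits;
}
// UVAIndexSelectCSCImpl receives cuda::NumberOfBits(nodes_max + 1), i.e. enough
// bits for any node ID in [0, nodes_max]; fewer key bits mean fewer radix sort
// passes over the nodes tensor.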
@@ -204,38 +203,39 @@ void IndexSelectCSCCopyIndices(
}
std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
const int64_t num_nodes = nodes.size(0);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
torch::TensorOptions options, torch::optional<int64_t> output_size) {
const int64_t num_nodes = sliced_indptr.size(0);
return AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
sliced_indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
using indptr_t = scalar_t;
auto in_degree =
std::get<0>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto sliced_indptr =
std::get<1>(in_degree_and_sliced_indptr).data_ptr<indptr_t>();
auto in_degree_ptr = in_degree.data_ptr<indptr_t>();
auto sliced_indptr_ptr = sliced_indptr.data_ptr<indptr_t>();
// Output indptr for the slice indexed by nodes.
torch::Tensor output_indptr = torch::empty(
num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
num_nodes + 1, options.dtype(sliced_indptr.scalar_type()));
// Compute the output indptr, output_indptr.
CUB_CALL(
DeviceScan::ExclusiveSum, in_degree,
DeviceScan::ExclusiveSum, in_degree_ptr,
output_indptr.data_ptr<indptr_t>(), num_nodes + 1);
// Number of edges being copied.
auto edge_count =
cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
if (!output_size.has_value()) {
auto edge_count =
cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
output_size = static_cast<indptr_t>(edge_count);
}
// Allocate output array of size number of copied edges.
torch::Tensor output_indices = torch::empty(
static_cast<indptr_t>(edge_count),
nodes.options().dtype(indices.scalar_type()));
output_size.value(), options.dtype(indices.scalar_type()));
GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "IndexSelectCSCCopyIndices", ([&] {
using indices_t = element_size_t;
IndexSelectCSCCopyIndices<indptr_t, indices_t>(
num_nodes, reinterpret_cast<indices_t*>(indices.data_ptr()),
sliced_indptr, in_degree, output_indptr.data_ptr<indptr_t>(),
sliced_indptr_ptr, in_degree_ptr,
output_indptr.data_ptr<indptr_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()));
}));
return std::make_tuple(output_indptr, output_indices);
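The output indptr in this device path is an exclusive prefix sum over in_degree (the DeviceScan::ExclusiveSum call above), whose last entry is the total number of edges to copy. For reference only, the same computation written with plain libtorch ops rather than the CUDA scan (assuming in_degree holds one degree per selected node):

#include <torch/torch.h>

// output_indptr[0] == 0 and output_indptr[i + 1] == in_degree[0] + ... +
// in_degree[i], so the result has num_nodes + 1 entries.
torch::Tensor ExclusiveCumSumSketch(torch::Tensor in_degree) {
  auto cumsum = in_degree.cumsum(0);  // integral inputs accumulate as int64
  auto zero = torch::zeros(1, cumsum.options());
  return torch::cat({zero, cumsum});
}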
@@ -243,13 +243,27 @@ std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
}
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
torch::Tensor nodes, int64_t nodes_max,
torch::optional<int64_t> output_size) {
if (indices.is_pinned()) {
return UVAIndexSelectCSCImpl(indptr, indices, nodes);
int num_bits = cuda::NumberOfBits(nodes_max + 1);
return UVAIndexSelectCSCImpl(
in_degree, sliced_indptr, indices, nodes, num_bits, output_size);
} else {
return DeviceIndexSelectCSCImpl(indptr, indices, nodes);
return DeviceIndexSelectCSCImpl(
in_degree, sliced_indptr, indices, nodes.options(), output_size);
}
}
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<int64_t> output_size) {
auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
return IndexSelectCSCImpl(
in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2,
output_size);
}
} // namespace ops
} // namespace graphbolt
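The nodes-based wrapper above recovers in_degree and sliced_indptr via SliceCSCIndptr and forwards output_size unchanged; when output_size is empty, the implementation reads the edge count back from the device before allocating. When indptr lives in host (or pinned) memory, a caller could instead precompute the exact size up front and pass it along; a hypothetical helper, not part of this commit:

#include <torch/torch.h>
#include <cstdint>

// Total number of edges IndexSelectCSC will copy for these nodes, i.e.
// (indptr[nodes + 1] - indptr[nodes]).sum(), evaluated on the host.
int64_t PrecomputeOutputSize(torch::Tensor indptr, torch::Tensor nodes) {
  auto begin = indptr.index_select(0, nodes);
  auto end = indptr.index_select(0, nodes + 1);
  return (end - begin).sum().item<int64_t>();
}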
@@ -16,15 +16,17 @@ namespace ops {
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<torch::Tensor> type_per_edge) {
auto [output_indptr, output_indices] =
IndexSelectCSCImpl(indptr, indices, nodes);
auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
auto [output_indptr, output_indices] = IndexSelectCSCImpl(
in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2);
const int64_t num_edges = output_indices.size(0);
torch::optional<torch::Tensor> output_type_per_edge;
if (type_per_edge) {
output_type_per_edge =
std::get<1>(IndexSelectCSCImpl(indptr, type_per_edge.value(), nodes));
output_type_per_edge = std::get<1>(IndexSelectCSCImpl(
in_degree, sliced_indptr, type_per_edge.value(), nodes,
indptr.size(0) - 2, num_edges));
}
auto rows = CSRToCOO(output_indptr, indices.scalar_type());
auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
auto i = torch::arange(output_indices.size(0), output_indptr.options());
auto edge_ids =
i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
......
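InSubgraph now calls SliceCSCIndptr once, reuses its outputs for both IndexSelectCSCImpl calls, and forwards num_edges from the indices slice as the output_size of the type_per_edge slice. The edge-id recovery shown above reads as: output position i falls in column rows[i], so its original edge id is sliced_indptr[rows[i]] + (i - output_indptr[rows[i]]). A tiny worked example with made-up values:

#include <torch/torch.h>

// Two selected columns holding 2 and 3 edges whose blocks start at offsets 4
// and 9 of the original indices array (illustrative numbers only).
void EdgeIdExample() {
  auto output_indptr = torch::tensor({0, 2, 5}, torch::kInt64);
  auto sliced_indptr = torch::tensor({4, 9}, torch::kInt64);
  auto rows = torch::tensor({0, 0, 1, 1, 1}, torch::kInt64);  // column per edge
  auto i = torch::arange(5, output_indptr.options());
  auto edge_ids =
      i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
  // edge_ids == [4, 5, 9, 10, 11]
}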
@@ -157,25 +157,30 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
torch::optional<int64_t> num_edges_;
torch::Tensor sub_indptr;
// @todo mfbalin, refactor IndexSelectCSCImpl so that it does not have to take
// nodes as input
torch::optional<torch::Tensor> sliced_probs_or_mask;
if (probs_or_mask.has_value()) {
torch::Tensor sliced_probs_or_mask_tensor;
std::tie(sub_indptr, sliced_probs_or_mask_tensor) =
IndexSelectCSCImpl(indptr, probs_or_mask.value(), nodes);
std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
in_degree, sliced_indptr, probs_or_mask.value(), nodes,
indptr.size(0) - 2, num_edges_);
sliced_probs_or_mask = sliced_probs_or_mask_tensor;
} else {
sub_indptr = ExclusiveCumSum(in_degree);
num_edges_ = sliced_probs_or_mask_tensor.size(0);
}
if (fanouts.size() > 1) {
torch::Tensor sliced_type_per_edge;
std::tie(sub_indptr, sliced_type_per_edge) =
IndexSelectCSCImpl(indptr, type_per_edge.value(), nodes);
std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
in_degree, sliced_indptr, type_per_edge.value(), nodes,
indptr.size(0) - 2, num_edges_);
std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
num_rows = sliced_indptr.size(0);
num_edges_ = sliced_type_per_edge.size(0);
}
// If sub_indptr was not computed in the two code blocks above:
if (!probs_or_mask.has_value() && fanouts.size() <= 1) {
sub_indptr = ExclusiveCumSum(in_degree);
}
auto max_in_degree = torch::empty(
1,
......
@@ -22,14 +22,15 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
}
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<int64_t> output_size) {
TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes); });
{ return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });
}
// @todo: The CPU supports only integer dtypes for indices tensor.
TORCH_CHECK(
......
@@ -25,11 +25,13 @@ namespace ops {
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<int64_t> output_size = torch::nullopt);
/**
* @brief Select rows from input tensor according to index tensor.
......
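The public IndexSelectCSC now accepts an optional output_size defaulting to torch::nullopt, so existing call sites keep working. A hypothetical C++ check mirroring the Python test below (CheckOutputSizeAgreement is an illustrative name): results with and without an explicit output_size should be identical.

#include <torch/torch.h>
// Assumes the header above, declaring graphbolt::ops::IndexSelectCSC, is included.

void CheckOutputSizeAgreement(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
  auto [ref_indptr, ref_indices] =
      graphbolt::ops::IndexSelectCSC(indptr, indices, nodes);
  // Passing the known edge count must not change the result, only skip the
  // edge-count readback before allocation.
  auto [out_indptr, out_indices] = graphbolt::ops::IndexSelectCSC(
      indptr, indices, nodes, ref_indices.size(0));
  TORCH_CHECK(ref_indptr.equal(out_indptr));
  TORCH_CHECK(ref_indices.equal(out_indices));
}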
@@ -22,7 +22,10 @@ from .. import gb_test_utils
)
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("is_pinned", [False, True])
def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
@pytest.mark.parametrize("output_size", [None, True])
def test_index_select_csc(
indptr_dtype, indices_dtype, idtype, is_pinned, output_size
):
"""Original graph in COO:
1 0 1 0 1 0
1 0 0 1 0 1
@@ -38,7 +41,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
index = torch.tensor([0, 5, 3], dtype=idtype)
cpu_indptr, cpu_indices = torch.ops.graphbolt.index_select_csc(
indptr, indices, index
indptr, indices, index, None
)
if is_pinned:
indptr = indptr.pin_memory()
@@ -48,10 +51,12 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
indices = indices.cuda()
index = index.cuda()
if output_size:
output_size = len(cpu_indices)
gpu_indptr, gpu_indices = torch.ops.graphbolt.index_select_csc(
indptr, indices, index
indptr, indices, index, output_size
)
assert not cpu_indptr.is_cuda
assert not cpu_indices.is_cuda
......