Unverified Commit 01df9bad authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] `InSubgraph` (#6830)

parent fa1ae3b7
......@@ -72,9 +72,7 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called if all input tensors are on device memory.
* NOTE: The shape of all tensors must be 1-D.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
......@@ -85,23 +83,6 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called if indices tensor is on pinned memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
* given nodes and their indptr values.
......
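For reference, a minimal pure-PyTorch sketch of what `IndexSelectCSCImpl` computes, i.e. CSC column selection with the output shapes given in the doc comment above. The function and variable names are illustrative only; this ignores the CUDA-specific device/pinned-memory dispatch.

import torch

def index_select_csc(indptr, indices, nodes):
    # In-degrees of the selected nodes.
    in_degree = indptr[nodes + 1] - indptr[nodes]
    # Output offsets, shape (M + 1,).
    output_indptr = torch.cat(
        [torch.zeros(1, dtype=indptr.dtype), in_degree.cumsum(0)]
    )
    # Concatenated neighbor blocks, shape ((indptr[nodes + 1] - indptr[nodes]).sum(),).
    output_indices = torch.cat(
        [indices[int(indptr[n]) : int(indptr[n + 1])] for n in nodes.tolist()]
    )
    return output_indptr, output_indices

indptr = torch.tensor([0, 2, 3, 5])
indices = torch.tensor([1, 2, 0, 0, 1])
nodes = torch.tensor([2, 0])
print(index_select_csc(indptr, indices, nodes))
# (tensor([0, 2, 4]), tensor([0, 1, 1, 2]))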
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file graphbolt/cuda_sampling_ops.h
* @brief Available CUDA sampling operations in Graphbolt.
*/
#include <graphbolt/fused_sampled_subgraph.h>
#include <torch/script.h>
namespace graphbolt {
namespace ops {
/**
 * @brief Return the subgraph induced on the inbound edges of the given nodes.
 *
 * @param indptr Indptr tensor containing offsets with shape (N,).
 * @param indices Indices tensor with edge information of shape (indptr[N],).
 * @param nodes Type agnostic node IDs to form the subgraph.
 * @param type_per_edge Optional tensor holding the type of each edge.
 *
 * @return The induced subgraph as a FusedSampledSubgraph.
*/
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<torch::Tensor> type_per_edge);
} // namespace ops
} // namespace graphbolt
......@@ -267,7 +267,7 @@ void IndexSelectCSCCopyIndices(
}
}
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
auto stream = cuda::GetCurrentStream();
const int64_t num_nodes = nodes.size(0);
......@@ -315,5 +315,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
}));
}
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
if (indices.is_pinned()) {
return UVAIndexSelectCSCImpl(indptr, indices, nodes);
} else {
return DeviceIndexSelectCSCImpl(indptr, indices, nodes);
}
}
} // namespace ops
} // namespace graphbolt
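The merged `IndexSelectCSCImpl` above picks the UVA kernel when the indices tensor lives in pinned host memory and the device kernel otherwise. A hedged sketch of how a caller controls which path is taken (tensor names are illustrative):

import torch

indices = torch.arange(10)
if torch.cuda.is_available():
    # Pinned host memory: GPU-accessible via UVA and is_pinned() is True,
    # so the UVA kernel would be chosen.
    pinned = indices.pin_memory()
    assert pinned.is_pinned() and not pinned.is_cuda
    # Device memory: is_pinned() is False, so the device kernel would be chosen.
    on_device = indices.cuda()
    assert on_device.is_cuda and not on_device.is_pinned()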
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/insubgraph.cu
* @brief InSubgraph operator implementation on CUDA.
*/
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <cub/cub.cuh>
#include "./common.h"
namespace graphbolt {
namespace ops {
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<torch::Tensor> type_per_edge) {
auto [output_indptr, output_indices] =
IndexSelectCSCImpl(indptr, indices, nodes);
torch::optional<torch::Tensor> output_type_per_edge;
if (type_per_edge) {
output_type_per_edge =
std::get<1>(IndexSelectCSCImpl(indptr, type_per_edge.value(), nodes));
}
auto rows = CSRToCOO(output_indptr, indices.scalar_type());
auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
auto i = torch::arange(output_indices.size(0), output_indptr.options());
auto edge_ids =
i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes, torch::nullopt, edge_ids,
output_type_per_edge);
}
} // namespace ops
} // namespace graphbolt
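The `edge_ids` expression above recovers, for every edge of the induced subgraph, its position in the original `indices` tensor: `i - output_indptr[row]` is the edge's offset within its seed's neighbor block, and adding `sliced_indptr[row] = indptr[nodes[row]]` shifts that offset back into the original CSC layout. A hedged pure-PyTorch sketch of the same arithmetic (names are illustrative, not GraphBolt API):

import torch

indptr = torch.tensor([0, 2, 3, 5])      # original CSC offsets
nodes = torch.tensor([2, 0])             # seed nodes
output_indptr = torch.tensor([0, 2, 4])  # offsets of the induced subgraph
# Seed index (row) of every output edge, the analogue of CSRToCOO.
rows = torch.repeat_interleave(
    torch.arange(nodes.numel()), output_indptr.diff()
)
sliced_indptr = indptr[nodes]            # analogue of SliceCSCIndptr
i = torch.arange(int(output_indptr[-1]))
edge_ids = i - output_indptr[rows] + sliced_indptr[rows]
print(edge_ids)  # tensor([3, 4, 0, 1]) -> original positions in `indices`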
......@@ -4,6 +4,7 @@
* @brief Source file of sampling graph.
*/
#include <graphbolt/cuda_sampling_ops.h>
#include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/serialize.h>
#include <torch/torch.h>
......@@ -16,6 +17,7 @@
#include <tuple>
#include <vector>
#include "./macro.h"
#include "./random.h"
#include "./shared_memory_helper.h"
#include "./utils.h"
......@@ -272,6 +274,15 @@ FusedCSCSamplingGraph::GetState() const {
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
if (utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!type_per_edge_.has_value() ||
utils::is_accessible_from_gpu(type_per_edge_.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);
});
}
using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100;
const auto num_seeds = nodes.size(0);
......
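With this check, `InSubgraph` runs on the GPU whenever `indptr_`, `indices_`, the seed `nodes`, and (if present) `type_per_edge_` are all GPU-accessible, that is, resident on the CUDA device or in pinned host memory; otherwise execution falls through to the existing CPU implementation below. The updated tests exercise this by moving the graph and seeds with `.to(F.ctx())`; a hedged end-to-end sketch of the same idea, assuming the Python `in_subgraph` wrapper:

import torch
import dgl.graphbolt as gb

indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
graph = gb.fused_csc_sampling_graph(indptr, indices)
nodes = torch.LongTensor([0, 5, 3])
if torch.cuda.is_available():
    # Every tensor becomes GPU-accessible, so the CUDA InSubgraph path is taken.
    graph = graph.to("cuda")
    nodes = nodes.cuda()
subgraph = graph.in_subgraph(nodes)  # otherwise the CPU path is used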
......@@ -26,17 +26,11 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (indices.is_pinned() && utils::is_accessible_from_gpu(indptr) &&
if (utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices) &&
utils::is_accessible_from_gpu(nodes)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "UVAIndexSelectCSC",
{ return UVAIndexSelectCSCImpl(indptr, indices, nodes); });
} else if (
indices.device().type() == c10::DeviceType::CUDA &&
utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(nodes)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "nodesSelectCSC",
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes); });
}
// @todo: The CPU implementation supports only integer dtypes for the indices tensor.
......
......@@ -439,11 +439,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype
(num.get(etype_id, 0),),
dtype=indices.dtype,
device=indices.device,
)
if has_original_eids:
original_hetero_edge_ids[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=original_edge_ids.dtype
(num.get(etype_id, 0),),
dtype=original_edge_ids.dtype,
device=original_edge_ids.device,
)
subgraph_indptr[etype] = [0]
subgraph_indice_position[etype] = 0
......@@ -486,7 +490,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_edge_ids = original_hetero_edge_ids
sampled_csc = {
etype: CSCFormatBase(
indptr=torch.tensor(subgraph_indptr[etype]),
indptr=torch.tensor(
subgraph_indptr[etype], device=indptr.device
),
indices=subgraph_indice[etype],
)
for etype in self.edge_type_to_id.keys()
......
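The `device=` arguments added above matter because `torch.empty` and `torch.tensor` allocate on the CPU by default; once `indices` and `original_edge_ids` may live on the GPU, the per-edge-type buffers and the rebuilt `indptr` must be created on the same device, or the later copies and comparisons would mix CPU and CUDA tensors. A short hedged illustration (device choice is illustrative):

import torch

if torch.cuda.is_available():
    indices = torch.arange(5, device="cuda")
    # Default allocation lands on the CPU ...
    assert torch.empty(3, dtype=indices.dtype).device.type == "cpu"
    # ... so buffers are allocated explicitly on the source tensor's device.
    buf = torch.empty(3, dtype=indices.dtype, device=indices.device)
    assert buf.device == indices.device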
......@@ -73,12 +73,14 @@ def test_InSubgraphSampler_homo():
"""
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
graph = gb.fused_csc_sampling_graph(indptr, indices)
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
seed_nodes = torch.LongTensor([0, 5, 3])
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
......@@ -92,25 +94,27 @@ def test_InSubgraphSampler_homo():
return _indices
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]))
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 3])
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 3]).to(F.ctx()),
)
assert torch.equal(original_indices(mn), torch.tensor([0, 1, 4]))
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]))
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
)
assert torch.equal(original_indices(mn), torch.tensor([1, 4]))
assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]))
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
)
assert torch.equal(original_indices(mn), torch.tensor([1, 2]))
assert torch.equal(original_indices(mn), torch.tensor([1, 2]).to(F.ctx()))
def test_InSubgraphSampler_hetero():
......@@ -149,7 +153,7 @@ def test_InSubgraphSampler_hetero():
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
).to(F.ctx())
item_set = gb.ItemSetDict(
{
......@@ -158,14 +162,18 @@ def test_InSubgraphSampler_hetero():
}
)
batch_size = 2
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
it = iter(in_subgraph_sampler)
mn = next(it)
assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0]))
assert torch.equal(
mn.seed_nodes["N0"], torch.LongTensor([1, 0]).to(F.ctx())
)
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]),
......@@ -182,13 +190,17 @@ def test_InSubgraphSampler_hetero():
),
}
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
assert torch.equal(
pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)
mn = next(it)
assert mn.seed_nodes == {
"N0": torch.LongTensor([2]),
"N1": torch.LongTensor([0]),
"N0": torch.LongTensor([2]).to(F.ctx()),
"N1": torch.LongTensor([0]).to(F.ctx()),
}
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
......@@ -205,11 +217,17 @@ def test_InSubgraphSampler_hetero():
),
}
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
assert torch.equal(
pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)
mn = next(it)
assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1]))
assert torch.equal(
mn.seed_nodes["N1"], torch.LongTensor([2, 1]).to(F.ctx())
)
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
......@@ -225,6 +243,14 @@ def test_InSubgraphSampler_hetero():
indices=torch.LongTensor([1, 2, 0]),
),
}
if graph.csc_indptr.is_cuda:
expected_sampled_csc["N0:R1:N1"] = gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([1, 0])
)
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
assert torch.equal(
pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)