Unverified Commit 01df9bad authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] `InSubgraph` (#6830)

parent fa1ae3b7
...@@ -72,9 +72,7 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements); ...@@ -72,9 +72,7 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
* @brief Select columns for a sparse matrix in a CSC format according to nodes * @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor. * tensor.
* *
* NOTE: * NOTE: The shape of all tensors must be 1-D.
* 1. The shape of all tensors must be 1-D.
* 2. Should be called if all input tensors are on device memory.
* *
* @param indptr Indptr tensor containing offsets with shape (N,). * @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],). * @param indices Indices tensor with edge information of shape (indptr[N],).
...@@ -85,23 +83,6 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements); ...@@ -85,23 +83,6 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes); torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called if indices tensor is on pinned memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/** /**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the * @brief Slices the indptr tensor with nodes and returns the indegrees of the
* given nodes and their indptr values. * given nodes and their indptr values.
......
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file graphbolt/cuda_sampling_ops.h
* @brief Available CUDA sampling operations in Graphbolt.
*/
#include <graphbolt/fused_sampled_subgraph.h>
#include <torch/script.h>
namespace graphbolt {
namespace ops {
/**
 * @brief Return the subgraph induced on the inbound edges of the given nodes,
 * i.e. slice the CSC graph to the columns selected by nodes.
 *
 * @param indptr Indptr tensor of the CSC graph containing offsets with shape
 * (N,).
 * @param indices Indices tensor with edge information of shape (indptr[N],).
 * @param nodes Type agnostic node IDs to form the subgraph.
 * @param type_per_edge Optional tensor holding the type of each edge; when
 * provided, the types of the selected edges are included in the result.
 *
 * @return FusedSampledSubgraph.
 */
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
torch::optional<torch::Tensor> type_per_edge);
} // namespace ops
} // namespace graphbolt
...@@ -267,7 +267,7 @@ void IndexSelectCSCCopyIndices( ...@@ -267,7 +267,7 @@ void IndexSelectCSCCopyIndices(
} }
} }
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) { torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
auto stream = cuda::GetCurrentStream(); auto stream = cuda::GetCurrentStream();
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
...@@ -315,5 +315,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( ...@@ -315,5 +315,14 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
})); }));
} }
/**
 * @brief Select columns of a CSC matrix for the given nodes, dispatching to
 * the UVA implementation when the indices tensor lives in pinned host memory
 * and to the on-device implementation otherwise.
 */
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
  return indices.is_pinned()
             ? UVAIndexSelectCSCImpl(indptr, indices, nodes)
             : DeviceIndexSelectCSCImpl(indptr, indices, nodes);
}
} // namespace ops } // namespace ops
} // namespace graphbolt } // namespace graphbolt
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/insubgraph.cu
* @brief InSubgraph operator implementation on CUDA.
*/
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>

#include <cub/cub.cuh>

#include <tuple>

#include "./common.h"
namespace graphbolt {
namespace ops {
/**
 * @brief Return the subgraph induced on the inbound edges of the given nodes.
 *
 * @param indptr Indptr tensor of the graph in CSC format, shape (N,).
 * @param indices Indices tensor of the graph in CSC format.
 * @param nodes Nodes whose inbound edges form the subgraph.
 * @param type_per_edge Optional per-edge type tensor; when provided, the
 * types of the selected edges are sliced and returned as well.
 *
 * @return FusedSampledSubgraph with the sliced indptr and indices, the seed
 * nodes, the original edge ids of the selected edges and, if provided, the
 * sliced edge types.
 */
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    torch::optional<torch::Tensor> type_per_edge) {
  auto [output_indptr, output_indices] =
      IndexSelectCSCImpl(indptr, indices, nodes);
  torch::optional<torch::Tensor> output_type_per_edge;
  if (type_per_edge) {
    // Slice the edge types with the same indptr and nodes; only the sliced
    // values are needed, the accompanying indptr matches output_indptr.
    output_type_per_edge =
        std::get<1>(IndexSelectCSCImpl(indptr, type_per_edge.value(), nodes));
  }
  // COO row id of each selected edge, i.e. which output column it belongs to.
  auto rows = CSRToCOO(output_indptr, indices.scalar_type());
  // Only the sliced indptr values are needed here; the in-degrees that
  // SliceCSCIndptr also returns were previously bound to an unused local.
  auto sliced_indptr = std::get<1>(SliceCSCIndptr(indptr, nodes));
  // For edge i in column rows[i], its offset within the column is
  // i - output_indptr[rows[i]]; adding the column's original start offset
  // sliced_indptr[rows[i]] recovers the edge id in the input graph.
  auto i = torch::arange(output_indices.size(0), output_indptr.options());
  auto edge_ids =
      i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes, torch::nullopt, edge_ids,
      output_type_per_edge);
}
} // namespace ops
} // namespace graphbolt
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* @brief Source file of sampling graph. * @brief Source file of sampling graph.
*/ */
#include <graphbolt/cuda_sampling_ops.h>
#include <graphbolt/fused_csc_sampling_graph.h> #include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/serialize.h> #include <graphbolt/serialize.h>
#include <torch/torch.h> #include <torch/torch.h>
...@@ -16,6 +17,7 @@ ...@@ -16,6 +17,7 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "./macro.h"
#include "./random.h" #include "./random.h"
#include "./shared_memory_helper.h" #include "./shared_memory_helper.h"
#include "./utils.h" #include "./utils.h"
...@@ -272,6 +274,15 @@ FusedCSCSamplingGraph::GetState() const { ...@@ -272,6 +274,15 @@ FusedCSCSamplingGraph::GetState() const {
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
if (utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!type_per_edge_.has_value() ||
utils::is_accessible_from_gpu(type_per_edge_.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);
});
}
using namespace torch::indexing; using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100; const int32_t kDefaultGrainSize = 100;
const auto num_seeds = nodes.size(0); const auto num_seeds = nodes.size(0);
......
...@@ -26,17 +26,11 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC( ...@@ -26,17 +26,11 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) { torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
TORCH_CHECK( TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors"); indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (indices.is_pinned() && utils::is_accessible_from_gpu(indptr) && if (utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices) &&
utils::is_accessible_from_gpu(nodes)) { utils::is_accessible_from_gpu(nodes)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "UVAIndexSelectCSC", c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return UVAIndexSelectCSCImpl(indptr, indices, nodes); });
} else if (
indices.device().type() == c10::DeviceType::CUDA &&
utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(nodes)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "nodesSelectCSC",
{ return IndexSelectCSCImpl(indptr, indices, nodes); }); { return IndexSelectCSCImpl(indptr, indices, nodes); });
} }
// @todo: The CPU supports only integer dtypes for indices tensor. // @todo: The CPU supports only integer dtypes for indices tensor.
......
...@@ -439,11 +439,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -439,11 +439,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty( subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype (num.get(etype_id, 0),),
dtype=indices.dtype,
device=indices.device,
) )
if has_original_eids: if has_original_eids:
original_hetero_edge_ids[etype] = torch.empty( original_hetero_edge_ids[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=original_edge_ids.dtype (num.get(etype_id, 0),),
dtype=original_edge_ids.dtype,
device=original_edge_ids.device,
) )
subgraph_indptr[etype] = [0] subgraph_indptr[etype] = [0]
subgraph_indice_position[etype] = 0 subgraph_indice_position[etype] = 0
...@@ -486,7 +490,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -486,7 +490,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
sampled_csc = { sampled_csc = {
etype: CSCFormatBase( etype: CSCFormatBase(
indptr=torch.tensor(subgraph_indptr[etype]), indptr=torch.tensor(
subgraph_indptr[etype], device=indptr.device
),
indices=subgraph_indice[etype], indices=subgraph_indice[etype],
) )
for etype in self.edge_type_to_id.keys() for etype in self.edge_type_to_id.keys()
......
...@@ -73,12 +73,14 @@ def test_InSubgraphSampler_homo(): ...@@ -73,12 +73,14 @@ def test_InSubgraphSampler_homo():
""" """
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
graph = gb.fused_csc_sampling_graph(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
seed_nodes = torch.LongTensor([0, 5, 3]) seed_nodes = torch.LongTensor([0, 5, 3])
item_set = gb.ItemSet(seed_nodes, names="seed_nodes") item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
batch_size = 1 batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
...@@ -92,25 +94,27 @@ def test_InSubgraphSampler_homo(): ...@@ -92,25 +94,27 @@ def test_InSubgraphSampler_homo():
return _indices return _indices
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0])) assert torch.equal(mn.seed_nodes, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 3]) mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 3]).to(F.ctx()),
) )
assert torch.equal(original_indices(mn), torch.tensor([0, 1, 4]))
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5])) assert torch.equal(mn.seed_nodes, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2]) mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
) )
assert torch.equal(original_indices(mn), torch.tensor([1, 4])) assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3])) assert torch.equal(mn.seed_nodes, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2]) mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
) )
assert torch.equal(original_indices(mn), torch.tensor([1, 2])) assert torch.equal(original_indices(mn), torch.tensor([1, 2]).to(F.ctx()))
def test_InSubgraphSampler_hetero(): def test_InSubgraphSampler_hetero():
...@@ -149,7 +153,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -149,7 +153,7 @@ def test_InSubgraphSampler_hetero():
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=ntypes, node_type_to_id=ntypes,
edge_type_to_id=etypes, edge_type_to_id=etypes,
) ).to(F.ctx())
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
...@@ -158,14 +162,18 @@ def test_InSubgraphSampler_hetero(): ...@@ -158,14 +162,18 @@ def test_InSubgraphSampler_hetero():
} }
) )
batch_size = 2 batch_size = 2
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
it = iter(in_subgraph_sampler) it = iter(in_subgraph_sampler)
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0])) assert torch.equal(
mn.seed_nodes["N0"], torch.LongTensor([1, 0]).to(F.ctx())
)
expected_sampled_csc = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]), indptr=torch.LongTensor([0, 1, 3]),
...@@ -182,13 +190,17 @@ def test_InSubgraphSampler_hetero(): ...@@ -182,13 +190,17 @@ def test_InSubgraphSampler_hetero():
), ),
} }
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices) assert torch.equal(
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr) pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)
mn = next(it) mn = next(it)
assert mn.seed_nodes == { assert mn.seed_nodes == {
"N0": torch.LongTensor([2]), "N0": torch.LongTensor([2]).to(F.ctx()),
"N1": torch.LongTensor([0]), "N1": torch.LongTensor([0]).to(F.ctx()),
} }
expected_sampled_csc = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
...@@ -205,11 +217,17 @@ def test_InSubgraphSampler_hetero(): ...@@ -205,11 +217,17 @@ def test_InSubgraphSampler_hetero():
), ),
} }
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices) assert torch.equal(
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr) pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1])) assert torch.equal(
mn.seed_nodes["N1"], torch.LongTensor([2, 1]).to(F.ctx())
)
expected_sampled_csc = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([]) indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
...@@ -225,6 +243,14 @@ def test_InSubgraphSampler_hetero(): ...@@ -225,6 +243,14 @@ def test_InSubgraphSampler_hetero():
indices=torch.LongTensor([1, 2, 0]), indices=torch.LongTensor([1, 2, 0]),
), ),
} }
if graph.csc_indptr.is_cuda:
expected_sampled_csc["N0:R1:N1"] = gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([1, 0])
)
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices) assert torch.equal(
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr) pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())
)
assert torch.equal(
pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment