Unverified commit eb434893, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] fix incorrect indptr of in_subgraph (#6555)

parent b35757a0
...@@ -186,33 +186,36 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph( ...@@ -186,33 +186,36 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
using namespace torch::indexing; using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100; const int32_t kDefaultGrainSize = 100;
torch::Tensor indptr = torch::zeros_like(indptr_); const auto num_seeds = nodes.size(0);
const size_t num_seeds = nodes.size(0); torch::Tensor indptr = torch::zeros({num_seeds + 1}, indptr_.dtype());
std::vector<torch::Tensor> indices_arr(num_seeds); std::vector<torch::Tensor> indices_arr(num_seeds);
torch::Tensor original_column_node_ids =
torch::zeros({num_seeds}, indptr_.dtype());
std::vector<torch::Tensor> edge_ids_arr(num_seeds); std::vector<torch::Tensor> edge_ids_arr(num_seeds);
std::vector<torch::Tensor> type_per_edge_arr(num_seeds); std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
torch::parallel_for(
0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const int64_t node_id = nodes[i].item<int64_t>();
const int64_t start_idx = indptr_[node_id].item<int64_t>();
const int64_t end_idx = indptr_[node_id + 1].item<int64_t>();
indptr[node_id + 1] = end_idx - start_idx;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(start_idx, end_idx);
if (type_per_edge_) {
type_per_edge_arr[i] =
type_per_edge_.value().slice(0, start_idx, end_idx);
}
}
});
const auto& nonzero_idx = torch::nonzero(indptr).reshape(-1); AT_DISPATCH_INTEGRAL_TYPES(
torch::Tensor compact_indptr = indptr_.scalar_type(), "InSubgraph", ([&] {
torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype()); torch::parallel_for(
compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx})); 0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const auto node_id = nodes[i].item<scalar_t>();
const auto start_idx = indptr_[node_id].item<scalar_t>();
const auto end_idx = indptr_[node_id + 1].item<scalar_t>();
indptr[i + 1] = end_idx - start_idx;
original_column_node_ids[i] = node_id;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(start_idx, end_idx);
if (type_per_edge_) {
type_per_edge_arr[i] =
type_per_edge_.value().slice(0, start_idx, end_idx);
}
}
});
}));
return c10::make_intrusive<FusedSampledSubgraph>( return c10::make_intrusive<FusedSampledSubgraph>(
compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1, indptr.cumsum(0), torch::cat(indices_arr), original_column_node_ids,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr), torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
type_per_edge_ type_per_edge_
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)} ? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
......
...@@ -289,7 +289,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -289,7 +289,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Return the subgraph induced on the inbound edges of the given nodes. """Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming An in subgraph is equivalent to creating a new graph using the incoming
edges of the given nodes. edges of the given nodes. Subgraph is compacted according to the order
of passed-in `nodes`.
Parameters Parameters
---------- ----------
......
...@@ -497,20 +497,20 @@ def test_in_subgraph_homogeneous(): ...@@ -497,20 +497,20 @@ def test_in_subgraph_homogeneous():
graph = gb.from_fused_csc(indptr, indices) graph = gb.from_fused_csc(indptr, indices)
# Extract in subgraph. # Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs[0], torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) in_subgraph.node_pairs[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs[1], torch.LongTensor([1, 1, 3, 3, 4, 4, 4]) in_subgraph.node_pairs[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
assert torch.equal( assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
) )
...@@ -564,7 +564,7 @@ def test_in_subgraph_heterogeneous(): ...@@ -564,7 +564,7 @@ def test_in_subgraph_heterogeneous():
# Extract in subgraph. # Extract in subgraph.
nodes = { nodes = {
"N0": torch.LongTensor([1]), "N0": torch.LongTensor([1]),
"N1": torch.LongTensor([1, 2]), "N1": torch.LongTensor([2, 1]),
} }
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes)
...@@ -576,10 +576,10 @@ def test_in_subgraph_heterogeneous(): ...@@ -576,10 +576,10 @@ def test_in_subgraph_heterogeneous():
in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([]) in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([1, 0]) in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([0, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([1, 2]) in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([2, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1]) in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1])
...@@ -588,15 +588,15 @@ def test_in_subgraph_heterogeneous(): ...@@ -588,15 +588,15 @@ def test_in_subgraph_heterogeneous():
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1]) in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([0, 1, 2]) in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([1, 2, 2]) in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
assert torch.equal( assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment