Unverified commit eb434893, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] fix incorrect indptr of in_subgraph (#6555)

parent b35757a0
......@@ -186,33 +186,36 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100;
torch::Tensor indptr = torch::zeros_like(indptr_);
const size_t num_seeds = nodes.size(0);
const auto num_seeds = nodes.size(0);
torch::Tensor indptr = torch::zeros({num_seeds + 1}, indptr_.dtype());
std::vector<torch::Tensor> indices_arr(num_seeds);
torch::Tensor original_column_node_ids =
torch::zeros({num_seeds}, indptr_.dtype());
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
torch::parallel_for(
0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const int64_t node_id = nodes[i].item<int64_t>();
const int64_t start_idx = indptr_[node_id].item<int64_t>();
const int64_t end_idx = indptr_[node_id + 1].item<int64_t>();
indptr[node_id + 1] = end_idx - start_idx;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(start_idx, end_idx);
if (type_per_edge_) {
type_per_edge_arr[i] =
type_per_edge_.value().slice(0, start_idx, end_idx);
}
}
});
const auto& nonzero_idx = torch::nonzero(indptr).reshape(-1);
torch::Tensor compact_indptr =
torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype());
compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx}));
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "InSubgraph", ([&] {
torch::parallel_for(
0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const auto node_id = nodes[i].item<scalar_t>();
const auto start_idx = indptr_[node_id].item<scalar_t>();
const auto end_idx = indptr_[node_id + 1].item<scalar_t>();
indptr[i + 1] = end_idx - start_idx;
original_column_node_ids[i] = node_id;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(start_idx, end_idx);
if (type_per_edge_) {
type_per_edge_arr[i] =
type_per_edge_.value().slice(0, start_idx, end_idx);
}
}
});
}));
return c10::make_intrusive<FusedSampledSubgraph>(
compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1,
indptr.cumsum(0), torch::cat(indices_arr), original_column_node_ids,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
type_per_edge_
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
......
......@@ -289,7 +289,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
edges of the given nodes.
edges of the given nodes. Subgraph is compacted according to the order
of passed-in `nodes`.
Parameters
----------
......
......@@ -497,20 +497,20 @@ def test_in_subgraph_homogeneous():
graph = gb.from_fused_csc(indptr, indices)
# Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4])
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
assert torch.equal(
in_subgraph.node_pairs[0], torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
in_subgraph.node_pairs[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
)
assert torch.equal(
in_subgraph.node_pairs[1], torch.LongTensor([1, 1, 3, 3, 4, 4, 4])
in_subgraph.node_pairs[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
)
......@@ -564,7 +564,7 @@ def test_in_subgraph_heterogeneous():
# Extract in subgraph.
nodes = {
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([1, 2]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes)
......@@ -576,10 +576,10 @@ def test_in_subgraph_heterogeneous():
in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([1, 0])
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([1, 2])
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([2, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1])
......@@ -588,15 +588,15 @@ def test_in_subgraph_heterogeneous():
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([0, 1, 2])
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([1, 2, 2])
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment