Unverified Commit 0b5abba8 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify the return type of `CSCSamplingGraph.in_subgraph()` (#6517)

parent 4c6e6543
......@@ -305,8 +305,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert len(torch.unique(nodes)) == len(
nodes
), "Nodes cannot have duplicate values."
# TODO: change the result to 'FusedSampledSubgraphImpl'.
return self._c_csc_graph.in_subgraph(nodes)
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_sampled_subgraph(
self,
......
......@@ -501,18 +501,17 @@ def test_in_subgraph_homogeneous():
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
assert torch.equal(in_subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
in_subgraph.node_pairs[0], torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
)
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal(
in_subgraph.original_row_node_ids, torch.arange(0, total_num_nodes)
in_subgraph.node_pairs[1], torch.LongTensor([1, 1, 3, 3, 4, 4, 4])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert in_subgraph.type_per_edge is None
@unittest.skipIf(
......@@ -567,19 +566,34 @@ def test_in_subgraph_heterogeneous():
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
assert torch.equal(in_subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
in_subgraph.node_pairs["N0:R0:N0"][0], torch.LongTensor([])
)
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal(
in_subgraph.original_row_node_ids, torch.arange(0, total_num_nodes)
in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([1, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([1, 2])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([0, 1, 2])
)
assert torch.equal(
in_subgraph.type_per_edge, torch.LongTensor([2, 2, 1, 3, 1, 3, 3])
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([1, 2, 2])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment