Unverified Commit 0b5abba8 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify the return type of `CSCSamplingGraph.in_subgraph()` (#6517)

parent 4c6e6543
...@@ -305,8 +305,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -305,8 +305,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert len(torch.unique(nodes)) == len( assert len(torch.unique(nodes)) == len(
nodes nodes
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
# TODO: change the result to 'FusedSampledSubgraphImpl'. _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._c_csc_graph.in_subgraph(nodes) return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_sampled_subgraph( def _convert_to_sampled_subgraph(
self, self,
......
...@@ -501,18 +501,17 @@ def test_in_subgraph_homogeneous(): ...@@ -501,18 +501,17 @@ def test_in_subgraph_homogeneous():
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(in_subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal( assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) in_subgraph.node_pairs[0], torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
) )
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal( assert torch.equal(
in_subgraph.original_row_node_ids, torch.arange(0, total_num_nodes) in_subgraph.node_pairs[1], torch.LongTensor([1, 1, 3, 3, 4, 4, 4])
) )
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal( assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
) )
assert in_subgraph.type_per_edge is None
@unittest.skipIf( @unittest.skipIf(
...@@ -567,19 +566,34 @@ def test_in_subgraph_heterogeneous(): ...@@ -567,19 +566,34 @@ def test_in_subgraph_heterogeneous():
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(in_subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal( assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) in_subgraph.node_pairs["N0:R0:N0"][0], torch.LongTensor([])
) )
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal( assert torch.equal(
in_subgraph.original_row_node_ids, torch.arange(0, total_num_nodes) in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([])
) )
assert torch.equal( assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([1, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([1, 2])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([0, 1, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.type_per_edge, torch.LongTensor([2, 2, 1, 3, 1, 3, 3]) in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([1, 2, 2])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment