Unverified Commit 2a715f2f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] skip empty mask (#6331)

parent 4ea2bd45
...@@ -242,6 +242,8 @@ class CSCSamplingGraph: ...@@ -242,6 +242,8 @@ class CSCSamplingGraph:
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id mask = type_per_edge == etype_id
if mask.count_nonzero() == 0:
continue
hetero_row = row[mask] - self.node_type_offset[src_ntype_id] hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = ( hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id] column[mask] - self.node_type_offset[dst_ntype_id]
......
...@@ -547,7 +547,7 @@ def test_sample_neighbors_hetero(labor): ...@@ -547,7 +547,7 @@ def test_sample_neighbors_hetero(labor):
metadata=metadata, metadata=metadata,
) )
# Generate subgraph via sample neighbors. # Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
...@@ -572,6 +572,27 @@ def test_sample_neighbors_hetero(labor): ...@@ -572,6 +572,27 @@ def test_sample_neighbors_hetero(labor):
assert subgraph.reverse_row_node_ids is None assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None assert subgraph.reverse_edge_ids is None
# Sample on single node type.
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_node_pairs = {
"n2:e2:n1": (
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
),
}
assert len(subgraph.node_pairs) == 1
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert subgraph.reverse_column_node_ids is None
assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
...@@ -634,8 +655,14 @@ def test_sample_neighbors_fanouts( ...@@ -634,8 +655,14 @@ def test_sample_neighbors_fanouts(
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1 assert (
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2 expected_sampled_num1 == 0
or subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
)
assert (
expected_sampled_num2 == 0
or subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
)
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment