Unverified Commit 498188dd authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] exclude nothing if edge is not found in node pairs (#6807)

parent b04c9797
"""Graphbolt sampled subgraph.""" """Graphbolt sampled subgraph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
...@@ -181,6 +182,10 @@ class SampledSubgraph: ...@@ -181,6 +182,10 @@ class SampledSubgraph:
index = {} index = {}
is_cscformat = 0 is_cscformat = 0
for etype, pair in self.node_pairs.items(): for etype, pair in self.node_pairs.items():
if etype not in edges:
# No edges need to be excluded.
index[etype] = None
continue
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
original_row_node_ids = ( original_row_node_ids = (
None None
...@@ -207,7 +212,7 @@ class SampledSubgraph: ...@@ -207,7 +212,7 @@ class SampledSubgraph:
) )
index[etype] = _exclude_homo_edges( index[etype] = _exclude_homo_edges(
reverse_edges, reverse_edges,
edges.get(etype), edges[etype],
assume_num_node_within_int32, assume_num_node_within_int32,
) )
if is_cscformat: if is_cscformat:
...@@ -266,8 +271,12 @@ def _relabel_two_arrays(lhs_array, rhs_array): ...@@ -266,8 +271,12 @@ def _relabel_two_arrays(lhs_array, rhs_array):
return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :] return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :]
def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32): def _exclude_homo_edges(
"""Return the indices of edges that are not in edges_to_exclude.""" edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: Tuple[torch.Tensor, torch.Tensor],
assume_num_node_within_int32: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32: if assume_num_node_within_int32:
val = edges[0] << 32 | edges[1] val = edges[0] << 32 | edges[1]
val_to_exclude = edges_to_exclude[0] << 32 | edges_to_exclude[1] val_to_exclude = edges_to_exclude[0] << 32 | edges_to_exclude[1]
...@@ -286,6 +295,8 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -286,6 +295,8 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
def _index_select(obj, index): def _index_select(obj, index):
if obj is None: if obj is None:
return None return None
if index is None:
return obj
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
return obj[index] return obj[index]
if isinstance(obj, tuple): if isinstance(obj, tuple):
...@@ -312,6 +323,8 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -312,6 +323,8 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
def _index_select(obj, index): def _index_select(obj, index):
if obj is None: if obj is None:
return None return None
if index is None:
return obj
if isinstance(obj, CSCFormatBase): if isinstance(obj, CSCFormatBase):
new_indices = obj.indices[index] new_indices = obj.indices[index]
new_indptr = torch.searchsorted(index, obj.indptr) new_indptr = torch.searchsorted(index, obj.indptr)
......
from functools import partial
import dgl import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest import pytest
...@@ -88,25 +90,27 @@ def to_link_batch(data): ...@@ -88,25 +90,27 @@ def to_link_batch(data):
def test_SubgraphSampler_Link(labor): def test_SubgraphSampler_Link(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs") itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2) datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(item_sampler, graph, fanouts) datapipe = Sampler(datapipe, graph, fanouts)
assert len(list(neighbor_dp)) == 5 datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor): def test_SubgraphSampler_Link_With_Negative(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs") itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2) datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) datapipe = Sampler(datapipe, graph, fanouts)
assert len(list(neighbor_dp)) == 5 datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
def get_hetero_graph(): def get_hetero_graph():
...@@ -163,12 +167,13 @@ def test_SubgraphSampler_Link_Hetero(labor): ...@@ -163,12 +167,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
} }
) )
item_sampler = gb.ItemSampler(itemset, batch_size=2) datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(item_sampler, graph, fanouts) datapipe = Sampler(datapipe, graph, fanouts)
assert len(list(neighbor_dp)) == 5 datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
...@@ -187,13 +192,14 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor): ...@@ -187,13 +192,14 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
} }
) )
item_sampler = gb.ItemSampler(itemset, batch_size=2) datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) datapipe = Sampler(datapipe, graph, fanouts)
assert len(list(neighbor_dp)) == 5 datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment