Unverified commit 498188dd, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] exclude nothing if edge is not found in node pairs (#6807)

parent b04c9797
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
from typing import Dict, Tuple, Union
......@@ -181,6 +182,10 @@ class SampledSubgraph:
index = {}
is_cscformat = 0
for etype, pair in self.node_pairs.items():
if etype not in edges:
# No edges need to be excluded.
index[etype] = None
continue
src_type, _, dst_type = etype_str_to_tuple(etype)
original_row_node_ids = (
None
......@@ -207,7 +212,7 @@ class SampledSubgraph:
)
index[etype] = _exclude_homo_edges(
reverse_edges,
edges.get(etype),
edges[etype],
assume_num_node_within_int32,
)
if is_cscformat:
......@@ -266,8 +271,12 @@ def _relabel_two_arrays(lhs_array, rhs_array):
return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :]
def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32):
"""Return the indices of edges that are not in edges_to_exclude."""
def _exclude_homo_edges(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: Tuple[torch.Tensor, torch.Tensor],
assume_num_node_within_int32: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32:
val = edges[0] << 32 | edges[1]
val_to_exclude = edges_to_exclude[0] << 32 | edges_to_exclude[1]
......@@ -286,6 +295,8 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
def _index_select(obj, index):
if obj is None:
return None
if index is None:
return obj
if isinstance(obj, torch.Tensor):
return obj[index]
if isinstance(obj, tuple):
......@@ -312,6 +323,8 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
def _index_select(obj, index):
if obj is None:
return None
if index is None:
return obj
if isinstance(obj, CSCFormatBase):
new_indices = obj.indices[index]
new_indptr = torch.searchsorted(index, obj.indptr)
......
from functools import partial
import dgl
import dgl.graphbolt as gb
import pytest
......@@ -88,25 +90,27 @@ def to_link_batch(data):
def test_SubgraphSampler_Link(labor):
    """Link-prediction sampling followed by seed-edge exclusion on a random graph.

    Builds one datapipe (item sampler -> neighbor sampler -> exclusion
    transform) and checks the number of mini-batches it yields.

    Parameters
    ----------
    labor : bool
        When truthy, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
    # arange(0, 20) reshaped to (-1, 2) gives 10 seed node pairs.
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # NOTE(review): presumably drops the seed edges from each sampled
    # subgraph; this must also work when a seed edge was never sampled.
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    # 10 seed pairs with batch_size=2 -> 5 mini-batches.
    assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor):
    """Link-prediction sampling with negative sampling and seed-edge exclusion.

    Builds one datapipe (item sampler -> uniform negative sampler ->
    neighbor sampler -> exclusion transform) and checks the number of
    mini-batches it yields.

    Parameters
    ----------
    labor : bool
        When truthy, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
    # arange(0, 20) reshaped to (-1, 2) gives 10 seed node pairs.
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    # NOTE(review): the trailing 1 is presumably the negative ratio -- confirm
    # against the UniformNegativeSampler signature.
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # NOTE(review): presumably drops the seed edges from each sampled
    # subgraph; this must also work when a seed edge was never sampled.
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    # 10 seed pairs with batch_size=2 -> 5 mini-batches.
    assert len(list(datapipe)) == 5
def get_hetero_graph():
......@@ -163,12 +167,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(neighbor_dp)) == 5
datapipe = Sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
......@@ -187,13 +192,14 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
datapipe = Sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.