"notebooks/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "24b7468fcce5d00d205d2b3e12a43a0a181b85e5"
Unverified Commit 4440ac75 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `exclude_seed_edges` to support seeds. (#7114)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 391f513e
...@@ -23,7 +23,7 @@ from .internal import ( ...@@ -23,7 +23,7 @@ from .internal import (
unique_and_compact, unique_and_compact,
unique_and_compact_csc_formats, unique_and_compact_csc_formats,
) )
from .utils import add_reverse_edges, exclude_seed_edges from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
def load_graphbolt(): def load_graphbolt():
......
...@@ -116,7 +116,9 @@ class SampledSubgraph: ...@@ -116,7 +116,9 @@ class SampledSubgraph:
self, self,
edges: Union[ edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Dict[str, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
], ],
assume_num_node_within_int32: bool = True, assume_num_node_within_int32: bool = True,
): ):
...@@ -183,7 +185,7 @@ class SampledSubgraph: ...@@ -183,7 +185,7 @@ class SampledSubgraph:
), "Values > int32 are not supported yet." ), "Values > int32 are not supported yet."
assert ( assert (
isinstance(self.sampled_csc, (CSCFormatBase, tuple)) isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), ( ) == isinstance(edges, (tuple, torch.Tensor)), (
"The sampled subgraph and the edges to exclude should be both " "The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous." "homogeneous or both heterogeneous."
) )
...@@ -200,9 +202,14 @@ class SampledSubgraph: ...@@ -200,9 +202,14 @@ class SampledSubgraph:
self.original_row_node_ids, self.original_row_node_ids,
self.original_column_node_ids, self.original_column_node_ids,
) )
index = _exclude_homo_edges( if isinstance(edges, torch.Tensor):
reverse_edges, edges, assume_num_node_within_int32 index = _exclude_homo_edges_2(
) reverse_edges, edges, assume_num_node_within_int32
)
else:
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
else: else:
index = {} index = {}
...@@ -227,11 +234,18 @@ class SampledSubgraph: ...@@ -227,11 +234,18 @@ class SampledSubgraph:
original_row_node_ids, original_row_node_ids,
original_column_node_ids, original_column_node_ids,
) )
index[etype] = _exclude_homo_edges( if isinstance(edges[etype], torch.Tensor):
reverse_edges, index[etype] = _exclude_homo_edges_2(
edges[etype], reverse_edges,
assume_num_node_within_int32, edges[etype],
) assume_num_node_within_int32,
)
else:
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
...@@ -289,6 +303,27 @@ def _exclude_homo_edges( ...@@ -289,6 +303,27 @@ def _exclude_homo_edges(
return torch.nonzero(mask, as_tuple=True)[0] return torch.nonzero(mask, as_tuple=True)[0]
def _exclude_homo_edges_2(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: torch.Tensor,
assume_num_node_within_int32: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32:
val = edges[0] << 32 | edges[1]
edges_to_exclude_trans = edges_to_exclude.T
val_to_exclude = (
edges_to_exclude_trans[0] << 32 | edges_to_exclude_trans[1]
)
else:
# TODO: Add support for value > int32.
raise NotImplementedError(
"Values out of range int32 are not supported yet"
)
mask = ~isin(val, val_to_exclude)
return torch.nonzero(mask, as_tuple=True)[0]
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
"""Slice the subgraph according to the index.""" """Slice the subgraph according to the index."""
......
...@@ -73,6 +73,74 @@ def add_reverse_edges( ...@@ -73,6 +73,74 @@ def add_reverse_edges(
return combined_edges return combined_edges
def add_reverse_edges_2(
    edges: Union[Dict[str, torch.Tensor], torch.Tensor],
    reverse_etypes_mapping: Dict[str, str] = None,
):
    r"""
    This function finds the reverse edges of the given `edges` and returns the
    composition of them. In a homogeneous graph, reverse edges have inverted
    source and destination node IDs. While in a heterogeneous graph, reversing
    also involves swapping node IDs and their types. This function could be
    used before `exclude_edges` function to help find targeting edges.
    Note: The found reverse edges may not really exist in the original graph.
    And repeated edges could be added because reverse edges may already exist
    in the `edges`.
    Parameters
    ----------
    edges : Union[Dict[str, torch.Tensor], torch.Tensor]
        - If sampled subgraph is homogeneous, then `edges` should be a N*2
          tensor.
        - If sampled subgraph is heterogeneous, then `edges` should be a
          dictionary of edge types and the corresponding edges to exclude.
    reverse_etypes_mapping : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge types.
    Returns
    -------
    Union[Dict[str, torch.Tensor], torch.Tensor]
        The node pairs contain both the original edges and their reverse
        counterparts.
    Examples
    --------
    >>> edges = {"A:r:B": torch.tensor([[0, 1],[1, 2]])}
    >>> print(gb.add_reverse_edges_2(edges, {"A:r:B": "B:rr:A"}))
    {'A:r:B': torch.tensor([[0, 1],[1, 2]]),
    'B:rr:A': torch.tensor([[1, 0],[2, 1]])}
    >>> edges = torch.tensor([[0, 1],[1, 2]])
    >>> print(gb.add_reverse_edges_2(edges))
    torch.tensor([[0, 1],[1, 2],[1, 0],[2, 1]])
    """
    if isinstance(edges, torch.Tensor):
        assert edges.ndim == 2 and edges.shape[1] == 2, (
            "Only tensor with shape N*2 is supported now, but got "
            + f"{edges.shape}."
        )
        # Flipping along dim 1 swaps each (src, dst) pair.
        reverse_edges = edges.flip(dims=(1,))
        return torch.cat((edges, reverse_edges))
    else:
        combined_edges = edges.copy()
        if reverse_etypes_mapping is None:
            # No mapping given for heterogeneous input: nothing to reverse.
            return combined_edges
        for etype, reverse_etype in reverse_etypes_mapping.items():
            if etype in edges:
                assert edges[etype].ndim == 2 and edges[etype].shape[1] == 2, (
                    "Only tensor with shape N*2 is supported now, but got "
                    + f"{edges[etype].shape}."
                )
                if reverse_etype in combined_edges:
                    # Reverse etype already present: append the reversed pairs.
                    combined_edges[reverse_etype] = torch.cat(
                        (
                            combined_edges[reverse_etype],
                            edges[etype].flip(dims=(1,)),
                        )
                    )
                else:
                    combined_edges[reverse_etype] = edges[etype].flip(
                        dims=(1,)
                    )
        return combined_edges
def exclude_seed_edges( def exclude_seed_edges(
minibatch: MiniBatch, minibatch: MiniBatch,
include_reverse_edges: bool = False, include_reverse_edges: bool = False,
...@@ -89,11 +157,18 @@ def exclude_seed_edges( ...@@ -89,11 +157,18 @@ def exclude_seed_edges(
reverse_etypes_mapping : Dict[str, str] = None reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types. The mapping from the original edge types to their reverse edge types.
""" """
edges_to_exclude = minibatch.node_pairs if minibatch.node_pairs is not None:
if include_reverse_edges: edges_to_exclude = minibatch.node_pairs
edges_to_exclude = add_reverse_edges( if include_reverse_edges:
minibatch.node_pairs, reverse_etypes_mapping edges_to_exclude = add_reverse_edges(
) minibatch.node_pairs, reverse_etypes_mapping
)
else:
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges_2(
edges_to_exclude, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [ minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude) subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs for subgraph in minibatch.sampled_subgraphs
......
...@@ -265,6 +265,238 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column): ...@@ -265,6 +265,238 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
_assert_container_equal(result.original_edge_ids, expected_edge_ids) _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_deduplicated_tensor(reverse_row, reverse_column):
    # Homogeneous subgraph; exclusion edge given as a 1*2 tensor.
    csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])
    )
    original_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    src_to_exclude = torch.tensor([11]) if reverse_row else torch.tensor([2])
    original_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    dst_to_exclude = torch.tensor([9]) if reverse_column else torch.tensor([4])
    subgraph = SampledSubgraphImpl(
        csc_formats,
        original_column_node_ids,
        original_row_node_ids,
        torch.Tensor([5, 9, 10]),
    )
    # Same shape (1, 2) as cat(...).view(1, -1) produced before.
    edges_to_exclude = torch.stack((src_to_exclude, dst_to_exclude), dim=1)
    result = subgraph.exclude_edges(edges_to_exclude)
    expected_csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
    )
    _assert_container_equal(result.sampled_csc, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids,
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None,
    )
    _assert_container_equal(
        result.original_row_node_ids,
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None,
    )
    _assert_container_equal(result.original_edge_ids, torch.Tensor([5, 9]))
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_duplicated_tensor(reverse_row, reverse_column):
    # Homogeneous subgraph with duplicated edges; exclusion edge as a tensor.
    csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 3, 3, 5]),
        indices=torch.tensor([0, 3, 3, 2, 2]),
    )
    original_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    src_to_exclude = torch.tensor([24]) if reverse_row else torch.tensor([3])
    original_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    dst_to_exclude = torch.tensor([11]) if reverse_column else torch.tensor([2])
    subgraph = SampledSubgraphImpl(
        csc_formats,
        original_column_node_ids,
        original_row_node_ids,
        torch.Tensor([5, 9, 9, 10, 10]),
    )
    # Same shape (1, 2) as cat(...).view(1, -1) produced before.
    edges_to_exclude = torch.stack((src_to_exclude, dst_to_exclude), dim=1)
    result = subgraph.exclude_edges(edges_to_exclude)
    expected_csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
    )
    _assert_container_equal(result.sampled_csc, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids,
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None,
    )
    _assert_container_equal(
        result.original_row_node_ids,
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None,
    )
    _assert_container_equal(result.original_edge_ids, torch.Tensor([5, 10, 10]))
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_deduplicated_tensor(reverse_row, reverse_column):
    # Heterogeneous subgraph; exclusion edges given as an N*2 tensor per etype.
    relation = "A:relation:B"
    csc_formats = {
        relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 2, 3]),
            indices=torch.tensor([2, 1, 0]),
        )
    }
    original_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    src_to_exclude = (
        torch.tensor([15, 13]) if reverse_row else torch.tensor([2, 0])
    )
    original_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    dst_to_exclude = (
        torch.tensor([10, 12]) if reverse_column else torch.tensor([0, 2])
    )
    subgraph = SampledSubgraphImpl(
        sampled_csc=csc_formats,
        original_column_node_ids=original_column_node_ids,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids={relation: torch.tensor([19, 20, 21])},
    )
    # Same (N, 2) layout as cat(...).view(2, -1).T produced before.
    edges_to_exclude = {
        relation: torch.stack((src_to_exclude, dst_to_exclude), dim=1)
    }
    result = subgraph.exclude_edges(edges_to_exclude)
    expected_csc_formats = {
        relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 0, 1, 1]),
            indices=torch.tensor([1]),
        )
    }
    _assert_container_equal(result.sampled_csc, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids,
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None,
    )
    _assert_container_equal(
        result.original_row_node_ids,
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None,
    )
    _assert_container_equal(
        result.original_edge_ids, {relation: torch.tensor([20])}
    )
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_duplicated_tensor(reverse_row, reverse_column):
    # Heterogeneous subgraph with duplicated edges; tensor-form exclusion.
    relation = "A:relation:B"
    csc_formats = {
        relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 2, 4, 5]),
            indices=torch.tensor([2, 2, 1, 1, 0]),
        )
    }
    original_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    src_to_exclude = (
        torch.tensor([15, 13]) if reverse_row else torch.tensor([2, 0])
    )
    original_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    dst_to_exclude = (
        torch.tensor([10, 12]) if reverse_column else torch.tensor([0, 2])
    )
    subgraph = SampledSubgraphImpl(
        sampled_csc=csc_formats,
        original_column_node_ids=original_column_node_ids,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids={relation: torch.tensor([19, 19, 20, 20, 21])},
    )
    # Same (N, 2) layout as cat(...).view(2, -1).T produced before.
    edges_to_exclude = {
        relation: torch.stack((src_to_exclude, dst_to_exclude), dim=1)
    }
    result = subgraph.exclude_edges(edges_to_exclude)
    expected_csc_formats = {
        relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 0, 2, 2]),
            indices=torch.tensor([1, 1]),
        )
    }
    _assert_container_equal(result.sampled_csc, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids,
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None,
    )
    _assert_container_equal(
        result.original_row_node_ids,
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None,
    )
    _assert_container_equal(
        result.original_edge_ids, {relation: torch.tensor([20, 20])}
    )
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "cpu", F._default_context_str == "cpu",
reason="`to` function needs GPU to test.", reason="`to` function needs GPU to test.",
......
...@@ -1057,8 +1057,7 @@ def test_SubgraphSampler_Link(sampler_type): ...@@ -1057,8 +1057,7 @@ def test_SubgraphSampler_Link(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
for data in datapipe: for data in datapipe:
assert torch.equal( assert torch.equal(
...@@ -1091,8 +1090,7 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type): ...@@ -1091,8 +1090,7 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type):
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
...@@ -1171,8 +1169,7 @@ def test_SubgraphSampler_Link_Hetero(sampler_type): ...@@ -1171,8 +1169,7 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
for data in datapipe: for data in datapipe:
for compacted_seeds in data.compacted_seeds.values(): for compacted_seeds in data.compacted_seeds.values():
...@@ -1228,8 +1225,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type): ...@@ -1228,8 +1225,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
...@@ -1274,8 +1270,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type): ...@@ -1274,8 +1270,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
...@@ -1321,8 +1316,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type): ...@@ -1321,8 +1316,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now. datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
......
import re
import unittest
from functools import partial
import backend as F
import dgl
import dgl.graphbolt as gb
import pytest
import torch
def test_add_reverse_edges_homo():
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([4, 5, 6, 7])
    combined = gb.add_reverse_edges((src, dst))
    # Forward edges come first, then the flipped (reverse) ones.
    assert torch.equal(combined[0], torch.cat((src, dst)))
    assert torch.equal(combined[1], torch.cat((dst, src)))
def test_add_reverse_edges_hetero():
    mapping = {"n1:e1:n2": "n2:e2:n1"}
    # Case 1: the reverse etype is absent from the original etypes.
    edges = {"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6]))}
    combined = gb.add_reverse_edges(edges, mapping)
    assert torch.equal(combined["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
    assert torch.equal(combined["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
    assert torch.equal(combined["n2:e2:n1"][0], torch.tensor([4, 5, 6]))
    assert torch.equal(combined["n2:e2:n1"][1], torch.tensor([0, 1, 2]))
    # Case 2: the reverse etype already exists; reverses are appended to it.
    edges = {
        "n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6])),
        "n2:e2:n1": (torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])),
    }
    combined = gb.add_reverse_edges(edges, mapping)
    assert torch.equal(combined["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
    assert torch.equal(combined["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
    assert torch.equal(
        combined["n2:e2:n1"][0], torch.tensor([7, 8, 9, 4, 5, 6])
    )
    assert torch.equal(
        combined["n2:e2:n1"][1], torch.tensor([10, 11, 12, 0, 1, 2])
    )
def test_add_reverse_edges_2_homo():
    edges = torch.stack(
        (torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 7])), dim=1
    )
    combined = gb.add_reverse_edges_2(edges)
    expected = torch.tensor(
        [[0, 4], [1, 5], [2, 6], [3, 7], [4, 0], [5, 1], [6, 2], [7, 3]]
    )
    assert torch.equal(combined, expected)
    # A 1-D tensor has the wrong shape and must be rejected.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Only tensor with shape N*2 is supported now, but got torch.Size([4])."
        ),
    ):
        gb.add_reverse_edges_2(torch.tensor([0, 1, 2, 3]))
def test_add_reverse_edges_2_hetero():
    mapping = {"n1:e1:n2": "n2:e2:n1"}
    forward = torch.tensor([[0, 4], [1, 5], [2, 6]])
    flipped = torch.tensor([[4, 0], [5, 1], [6, 2]])
    # Case 1: the reverse etype is absent from the original etypes.
    combined = gb.add_reverse_edges_2({"n1:e1:n2": forward}, mapping)
    assert torch.equal(combined["n1:e1:n2"], forward)
    assert torch.equal(combined["n2:e2:n1"], flipped)
    # Case 2: the reverse etype already exists; reverses are appended to it.
    existing = torch.tensor([[7, 10], [8, 11], [9, 12]])
    combined = gb.add_reverse_edges_2(
        {"n1:e1:n2": forward, "n2:e2:n1": existing}, mapping
    )
    assert torch.equal(combined["n1:e1:n2"], forward)
    assert torch.equal(
        combined["n2:e2:n1"], torch.cat((existing, flipped))
    )
    # A 1-D tensor has the wrong shape and must be rejected.
    bad = {
        "n1:e1:n2": torch.tensor([0, 1, 2]),
        "n2:e2:n1": torch.tensor([7, 8, 9]),
    }
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Only tensor with shape N*2 is supported now, but got torch.Size([3])."
        ),
    ):
        gb.add_reverse_edges_2(bad, mapping)
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Fails due to different result on the GPU.",
)
def test_exclude_seed_edges_homo_cpu():
    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    itemset = gb.ItemSet(torch.LongTensor([[0, 3], [4, 4]]), names="seeds")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    fanouts = [torch.LongTensor([2]) for _ in range(2)]
    datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    # Per-layer expectations: (row_node_ids, csc indices, csc indptr, seeds).
    expected = [
        (
            torch.tensor([0, 3, 4, 5, 2, 6, 7]),
            torch.tensor([3, 4, 4, 5, 6]),
            torch.tensor([0, 1, 2, 3, 3, 5]),
            torch.tensor([0, 3, 4, 5, 2]),
        ),
        (
            torch.tensor([0, 3, 4, 5, 2]),
            torch.tensor([3, 4, 4]),
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([0, 3, 4]),
        ),
    ]
    for data in datapipe:
        for subgraph, (rows, indices, indptr, cols) in zip(
            data.sampled_subgraphs, expected
        ):
            assert torch.equal(
                subgraph.original_row_node_ids, rows.to(F.ctx())
            )
            assert torch.equal(
                subgraph.sampled_csc.indices, indices.to(F.ctx())
            )
            assert torch.equal(subgraph.sampled_csc.indptr, indptr.to(F.ctx()))
            assert torch.equal(
                subgraph.original_column_node_ids, cols.to(F.ctx())
            )
@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Fails due to different result on the CPU.",
)
def test_exclude_seed_edges_gpu():
    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
    itemset = gb.ItemSet(torch.LongTensor([[0, 3], [4, 4]]), names="seeds")
    datapipe = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
    fanouts = [torch.LongTensor([-1]) for _ in range(2)]
    datapipe = gb.NeighborSampler(
        datapipe,
        graph,
        fanouts,
        deduplicate=True,
    )
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    # Per-layer expectations: (row_node_ids, csc indices, csc indptr, seeds).
    expected = [
        (
            torch.tensor([0, 3, 4, 2, 5, 7]),
            torch.tensor([4, 3, 5, 5]),
            torch.tensor([0, 1, 2, 2, 4, 4]),
            torch.tensor([0, 3, 4, 2, 5]),
        ),
        (
            torch.tensor([0, 3, 4, 2, 5]),
            torch.tensor([4, 3]),
            torch.tensor([0, 1, 2, 2]),
            torch.tensor([0, 3, 4]),
        ),
    ]
    for data in datapipe:
        for subgraph, (rows, indices, indptr, cols) in zip(
            data.sampled_subgraphs, expected
        ):
            assert torch.equal(
                subgraph.original_row_node_ids, rows.to(F.ctx())
            )
            assert torch.equal(
                subgraph.sampled_csc.indices, indices.to(F.ctx())
            )
            assert torch.equal(subgraph.sampled_csc.indptr, indptr.to(F.ctx()))
            assert torch.equal(
                subgraph.original_column_node_ids, cols.to(F.ctx())
            )
def get_hetero_graph():
    # COO graph:
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    return gb.fused_csc_sampling_graph(
        torch.LongTensor([0, 2, 4, 6, 8, 10]),
        torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]),
        node_type_offset=torch.LongTensor([0, 2, 5]),
        type_per_edge=torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
        node_type_to_id={"n1": 0, "n2": 1},
        edge_type_to_id={"n1:e1:n2": 0, "n2:e2:n1": 1},
    )
def test_exclude_seed_edges_hetero():
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n1:e1:n2": gb.ItemSet(torch.tensor([[0, 1]]), names="seeds")}
    )
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    fanouts = [torch.LongTensor([2]) for _ in range(2)]
    datapipe = gb.NeighborSampler(
        datapipe,
        graph,
        fanouts,
        deduplicate=True,
    )
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    # Expected sampled CSC per layer after excluding the seed edges.
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 3, 5]),
                indices=torch.tensor([1, 0, 1, 0, 1]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([1, 2, 1, 0]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 1]),
                indices=torch.tensor([1]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2]),
                indices=torch.tensor([1, 2], dtype=torch.int64),
            ),
        },
    ]
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0]),
            "n2": torch.tensor([1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
    ]
    for data in datapipe:
        for subgraph, rows, cols, formats in zip(
            data.sampled_subgraphs,
            original_row_node_ids,
            original_column_node_ids,
            csc_formats,
        ):
            for ntype in ("n1", "n2"):
                # Node IDs are compared order-insensitively via sort.
                assert torch.equal(
                    torch.sort(subgraph.original_row_node_ids[ntype])[0],
                    rows[ntype].to(F.ctx()),
                )
                assert torch.equal(
                    torch.sort(subgraph.original_column_node_ids[ntype])[0],
                    cols[ntype].to(F.ctx()),
                )
            for etype in ("n1:e1:n2", "n2:e2:n1"):
                assert torch.equal(
                    subgraph.sampled_csc[etype].indices,
                    formats[etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.sampled_csc[etype].indptr,
                    formats[etype].indptr.to(F.ctx()),
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment