Unverified Commit 9c36ddcd authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `exclude_edges` to support csc format. (#6648)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 5185c522
......@@ -6,7 +6,7 @@ import torch
from dgl.utils import recursive_apply
from .base import apply_to, etype_str_to_tuple, isin
from .base import apply_to, CSCFormatBase, etype_str_to_tuple, isin
__all__ = ["SampledSubgraph"]
......@@ -144,7 +144,9 @@ class SampledSubgraph:
assert (
assume_num_node_within_int32
), "Values > int32 are not supported yet."
assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), (
assert (
isinstance(self.node_pairs, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
)
......@@ -156,6 +158,16 @@ class SampledSubgraph:
# 2. Exclude the edges and get the index of the edges to keep.
# 3. Slice the subgraph according to the index.
if isinstance(self.node_pairs, tuple):
reverse_edges = _to_reverse_ids_node_pairs(
self.node_pairs,
self.original_row_node_ids,
self.original_column_node_ids,
)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph_node_pairs(self, index))
elif isinstance(self.node_pairs, CSCFormatBase):
reverse_edges = _to_reverse_ids(
self.node_pairs,
self.original_row_node_ids,
......@@ -167,6 +179,7 @@ class SampledSubgraph:
return calling_class(*_slice_subgraph(self, index))
else:
index = {}
is_cscformat = 0
for etype, pair in self.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
original_row_node_ids = (
......@@ -179,17 +192,28 @@ class SampledSubgraph:
if self.original_column_node_ids is None
else self.original_column_node_ids.get(dst_type)
)
reverse_edges = _to_reverse_ids(
pair,
original_row_node_ids,
original_column_node_ids,
)
if isinstance(pair, CSCFormatBase):
is_cscformat = 1
reverse_edges = _to_reverse_ids(
pair,
original_row_node_ids,
original_column_node_ids,
)
else:
reverse_edges = _to_reverse_ids_node_pairs(
pair,
original_row_node_ids,
original_column_node_ids,
)
index[etype] = _exclude_homo_edges(
reverse_edges,
edges.get(etype),
assume_num_node_within_int32,
)
return calling_class(*_slice_subgraph(self, index))
if is_cscformat:
return calling_class(*_slice_subgraph(self, index))
else:
return calling_class(*_slice_subgraph_node_pairs(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection."""
......@@ -208,7 +232,9 @@ class SampledSubgraph:
return self
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
def _to_reverse_ids_node_pairs(
node_pair, original_row_node_ids, original_column_node_ids
):
u, v = node_pair
if original_row_node_ids is not None:
u = original_row_node_ids[u]
......@@ -217,6 +243,22 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
return (u, v)
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
indptr = node_pair.indptr
indices = node_pair.indices
if original_row_node_ids is not None:
indices = original_row_node_ids[indices]
if original_column_node_ids is not None:
indptr = original_column_node_ids.repeat_interleave(
indptr[1:] - indptr[:-1]
)
else:
indptr = torch.arange(len(indptr) - 1).repeat_interleave(
indptr[1:] - indptr[:-1]
)
return (indices, indptr)
def _relabel_two_arrays(lhs_array, rhs_array):
"""Relabel two arrays into a consecutive range starting from 0."""
concated = torch.cat([lhs_array, rhs_array])
......@@ -238,7 +280,7 @@ def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32):
return torch.nonzero(mask, as_tuple=True)[0]
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
"""Slice the subgraph according to the index."""
def _index_select(obj, index):
......@@ -262,3 +304,34 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index),
)
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
"""Slice the subgraph according to the index."""
def _index_select(obj, index):
if obj is None:
return None
if isinstance(obj, CSCFormatBase):
new_indices = obj.indices[index]
new_indptr = torch.searchsorted(index, obj.indptr)
return CSCFormatBase(
indptr=new_indptr,
indices=new_indices,
)
if isinstance(obj, torch.Tensor):
return obj[index]
# Handle the case when obj is a dictionary.
assert isinstance(obj, dict)
assert isinstance(index, dict)
ret = {}
for k, v in obj.items():
ret[k] = _index_select(v, index[k])
return ret
return (
_index_select(subgraph.node_pairs, index),
subgraph.original_column_node_ids,
subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index),
)
......@@ -20,6 +20,12 @@ def _assert_container_equal(lhs, rhs):
assert len(lhs) == len(rhs)
for l, r in zip(lhs, rhs):
_assert_container_equal(l, r)
elif isinstance(lhs, gb.CSCFormatBase):
assert isinstance(rhs, gb.CSCFormatBase)
assert len(lhs.indptr) == len(rhs.indptr)
assert len(lhs.indices) == len(rhs.indices)
_assert_container_equal(lhs.indptr, rhs.indptr)
_assert_container_equal(lhs.indices, rhs.indices)
elif isinstance(lhs, dict):
assert isinstance(rhs, dict)
assert len(lhs) == len(rhs)
......@@ -30,7 +36,7 @@ def _assert_container_equal(lhs, rhs):
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo(reverse_row, reverse_column):
def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column):
node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2]))
if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
......@@ -75,7 +81,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero(reverse_row, reverse_column):
def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
node_pairs = {
"A:relation:B": (
torch.tensor([0, 1, 2]),
......@@ -141,6 +147,240 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])
)
if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([11])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2])
if reverse_column:
original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([9])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([4])
original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl(
csc_formats,
original_column_node_ids,
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
)
if reverse_row:
expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 9])
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 3, 3, 5]),
indices=torch.tensor([0, 3, 3, 2, 2]),
)
if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([24])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([3])
if reverse_column:
original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([11])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([2])
original_edge_ids = torch.Tensor([5, 9, 9, 10, 10])
subgraph = SampledSubgraphImpl(
csc_formats,
original_column_node_ids,
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
)
if reverse_row:
expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10, 10])
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([2, 1, 0]),
)
}
if reverse_row:
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0])
if reverse_column:
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
)
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1]),
indices=torch.tensor([1]),
)
}
if reverse_row:
expected_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
else:
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 5]),
indices=torch.tensor([2, 2, 1, 1, 0]),
)
}
if reverse_row:
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0])
if reverse_column:
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 19, 20, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
)
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 2, 2]),
indices=torch.tensor([1, 1]),
)
}
if reverse_row:
expected_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
else:
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20, 20])}
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment