"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "b53a9be2ad8dc49ca924a0be2bf9cd1a9979b9bd"
Unverified commit 9c36ddcd, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Modify `exclude_edges` to support csc format. (#6648)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 5185c522
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from .base import apply_to, etype_str_to_tuple, isin from .base import apply_to, CSCFormatBase, etype_str_to_tuple, isin
__all__ = ["SampledSubgraph"] __all__ = ["SampledSubgraph"]
...@@ -144,7 +144,9 @@ class SampledSubgraph: ...@@ -144,7 +144,9 @@ class SampledSubgraph:
assert ( assert (
assume_num_node_within_int32 assume_num_node_within_int32
), "Values > int32 are not supported yet." ), "Values > int32 are not supported yet."
assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), ( assert (
isinstance(self.node_pairs, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both " "The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous." "homogeneous or both heterogeneous."
) )
...@@ -156,6 +158,16 @@ class SampledSubgraph: ...@@ -156,6 +158,16 @@ class SampledSubgraph:
# 2. Exclude the edges and get the index of the edges to keep. # 2. Exclude the edges and get the index of the edges to keep.
# 3. Slice the subgraph according to the index. # 3. Slice the subgraph according to the index.
if isinstance(self.node_pairs, tuple): if isinstance(self.node_pairs, tuple):
reverse_edges = _to_reverse_ids_node_pairs(
self.node_pairs,
self.original_row_node_ids,
self.original_column_node_ids,
)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph_node_pairs(self, index))
elif isinstance(self.node_pairs, CSCFormatBase):
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
self.node_pairs, self.node_pairs,
self.original_row_node_ids, self.original_row_node_ids,
...@@ -167,6 +179,7 @@ class SampledSubgraph: ...@@ -167,6 +179,7 @@ class SampledSubgraph:
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
else: else:
index = {} index = {}
is_cscformat = 0
for etype, pair in self.node_pairs.items(): for etype, pair in self.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
original_row_node_ids = ( original_row_node_ids = (
...@@ -179,17 +192,28 @@ class SampledSubgraph: ...@@ -179,17 +192,28 @@ class SampledSubgraph:
if self.original_column_node_ids is None if self.original_column_node_ids is None
else self.original_column_node_ids.get(dst_type) else self.original_column_node_ids.get(dst_type)
) )
if isinstance(pair, CSCFormatBase):
is_cscformat = 1
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
pair, pair,
original_row_node_ids, original_row_node_ids,
original_column_node_ids, original_column_node_ids,
) )
else:
reverse_edges = _to_reverse_ids_node_pairs(
pair,
original_row_node_ids,
original_column_node_ids,
)
index[etype] = _exclude_homo_edges( index[etype] = _exclude_homo_edges(
reverse_edges, reverse_edges,
edges.get(etype), edges.get(etype),
assume_num_node_within_int32, assume_num_node_within_int32,
) )
if is_cscformat:
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
else:
return calling_class(*_slice_subgraph_node_pairs(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection.""" """Copy `SampledSubgraph` to the specified device using reflection."""
...@@ -208,7 +232,9 @@ class SampledSubgraph: ...@@ -208,7 +232,9 @@ class SampledSubgraph:
return self return self
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids): def _to_reverse_ids_node_pairs(
node_pair, original_row_node_ids, original_column_node_ids
):
u, v = node_pair u, v = node_pair
if original_row_node_ids is not None: if original_row_node_ids is not None:
u = original_row_node_ids[u] u = original_row_node_ids[u]
...@@ -217,6 +243,22 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids): ...@@ -217,6 +243,22 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
return (u, v) return (u, v)
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
    """Expand a CSC-style adjacency into per-edge ``(row, column)`` ID tensors.

    Parameters
    ----------
    node_pair
        Object exposing ``indptr`` and ``indices`` tensors.
    original_row_node_ids : torch.Tensor or None
        When given, it is indexed with ``indices`` to translate the row IDs.
        When ``None``, ``indices`` is returned unchanged.
    original_column_node_ids : torch.Tensor or None
        When given, entry ``i`` is repeated once for every edge stored in
        column ``i``. When ``None``, the column position ``i`` itself is
        repeated instead.

    Returns
    -------
    tuple of torch.Tensor
        ``(rows, columns)`` with one entry per edge.
    """
    indptr = node_pair.indptr
    rows = node_pair.indices
    if original_row_node_ids is not None:
        rows = original_row_node_ids[rows]
    # Number of edges stored in each column.
    degrees = indptr[1:] - indptr[:-1]
    if original_column_node_ids is not None:
        column_ids = original_column_node_ids
    else:
        # Create the positional IDs on the same device as `indptr` so that
        # `repeat_interleave` never mixes devices.
        column_ids = torch.arange(len(indptr) - 1, device=indptr.device)
    return (rows, column_ids.repeat_interleave(degrees))
def _relabel_two_arrays(lhs_array, rhs_array): def _relabel_two_arrays(lhs_array, rhs_array):
"""Relabel two arrays into a consecutive range starting from 0.""" """Relabel two arrays into a consecutive range starting from 0."""
concated = torch.cat([lhs_array, rhs_array]) concated = torch.cat([lhs_array, rhs_array])
...@@ -238,7 +280,7 @@ def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32): ...@@ -238,7 +280,7 @@ def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32):
return torch.nonzero(mask, as_tuple=True)[0] return torch.nonzero(mask, as_tuple=True)[0]
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
"""Slice the subgraph according to the index.""" """Slice the subgraph according to the index."""
def _index_select(obj, index): def _index_select(obj, index):
...@@ -262,3 +304,34 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -262,3 +304,34 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
subgraph.original_row_node_ids, subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index), _index_select(subgraph.original_edge_ids, index),
) )
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
    """Keep only the edges of ``subgraph`` selected by ``index``.

    Returns a 4-tuple ``(node_pairs, original_column_node_ids,
    original_row_node_ids, original_edge_ids)`` in which ``node_pairs`` and
    ``original_edge_ids`` have been sliced and the two node ID fields are
    passed through untouched.
    """

    def _pick(item, kept):
        # Nothing to slice.
        if item is None:
            return None
        if isinstance(item, CSCFormatBase):
            # For every old column boundary, count how many kept edge
            # positions lie before it; that count is the new boundary.
            # NOTE(review): torch.searchsorted expects `kept` to be sorted
            # ascending - confirm the caller guarantees this.
            return CSCFormatBase(
                indices=item.indices[kept],
                indptr=torch.searchsorted(kept, item.indptr),
            )
        if isinstance(item, torch.Tensor):
            return item[kept]
        # Otherwise both the object and the index are dictionaries sharing
        # the same keys; slice entry by entry.
        assert isinstance(item, dict)
        assert isinstance(kept, dict)
        return {key: _pick(value, kept[key]) for key, value in item.items()}

    sliced_node_pairs = _pick(subgraph.node_pairs, index)
    return (
        sliced_node_pairs,
        subgraph.original_column_node_ids,
        subgraph.original_row_node_ids,
        _pick(subgraph.original_edge_ids, index),
    )
...@@ -20,6 +20,12 @@ def _assert_container_equal(lhs, rhs): ...@@ -20,6 +20,12 @@ def _assert_container_equal(lhs, rhs):
assert len(lhs) == len(rhs) assert len(lhs) == len(rhs)
for l, r in zip(lhs, rhs): for l, r in zip(lhs, rhs):
_assert_container_equal(l, r) _assert_container_equal(l, r)
elif isinstance(lhs, gb.CSCFormatBase):
assert isinstance(rhs, gb.CSCFormatBase)
assert len(lhs.indptr) == len(rhs.indptr)
assert len(lhs.indices) == len(rhs.indices)
_assert_container_equal(lhs.indptr, rhs.indptr)
_assert_container_equal(lhs.indices, rhs.indices)
elif isinstance(lhs, dict): elif isinstance(lhs, dict):
assert isinstance(rhs, dict) assert isinstance(rhs, dict)
assert len(lhs) == len(rhs) assert len(lhs) == len(rhs)
...@@ -30,7 +36,7 @@ def _assert_container_equal(lhs, rhs): ...@@ -30,7 +36,7 @@ def _assert_container_equal(lhs, rhs):
@pytest.mark.parametrize("reverse_row", [True, False]) @pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False]) @pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo(reverse_row, reverse_column): def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column):
node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2])) node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2]))
if reverse_row: if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9]) original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
...@@ -75,7 +81,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column): ...@@ -75,7 +81,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@pytest.mark.parametrize("reverse_row", [True, False]) @pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False]) @pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero(reverse_row, reverse_column): def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
node_pairs = { node_pairs = {
"A:relation:B": ( "A:relation:B": (
torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]),
...@@ -141,6 +147,240 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -141,6 +147,240 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal(result.original_edge_ids, expected_edge_ids) _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
    """Exclude one edge from a homogeneous CSC subgraph without repeats."""
    csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 2, 2, 3]),
        indices=torch.tensor([0, 3, 2]),
    )
    # The excluded edge is given in original IDs when a mapping is supplied
    # for that side, and in local IDs otherwise.
    original_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    src_to_exclude = torch.tensor([11]) if reverse_row else torch.tensor([2])
    original_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    dst_to_exclude = torch.tensor([9]) if reverse_column else torch.tensor([4])
    original_edge_ids = torch.Tensor([5, 9, 10])

    subgraph = SampledSubgraphImpl(
        csc_formats,
        original_column_node_ids,
        original_row_node_ids,
        original_edge_ids,
    )
    result = subgraph.exclude_edges((src_to_exclude, dst_to_exclude))

    # The last column loses its only edge; the node ID maps are unchanged.
    expected_csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 2, 2, 2]),
        indices=torch.tensor([0, 3]),
    )
    expected_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    expected_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    expected_edge_ids = torch.Tensor([5, 9])

    _assert_container_equal(result.node_pairs, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
    """Exclude a repeated edge from a homogeneous CSC subgraph."""
    csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 3, 3, 5]),
        indices=torch.tensor([0, 3, 3, 2, 2]),
    )
    # The excluded edge is given in original IDs when a mapping is supplied
    # for that side, and in local IDs otherwise.
    original_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    src_to_exclude = torch.tensor([24]) if reverse_row else torch.tensor([3])
    original_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    dst_to_exclude = (
        torch.tensor([11]) if reverse_column else torch.tensor([2])
    )
    original_edge_ids = torch.Tensor([5, 9, 9, 10, 10])

    subgraph = SampledSubgraphImpl(
        csc_formats,
        original_column_node_ids,
        original_row_node_ids,
        original_edge_ids,
    )
    result = subgraph.exclude_edges((src_to_exclude, dst_to_exclude))

    # Both copies of the excluded edge are expected to disappear.
    expected_csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 1, 1, 3]),
        indices=torch.tensor([0, 2, 2]),
    )
    expected_row_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_row else None
    )
    expected_column_node_ids = (
        torch.tensor([10, 15, 11, 24, 9]) if reverse_column else None
    )
    expected_edge_ids = torch.Tensor([5, 10, 10])

    _assert_container_equal(result.node_pairs, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
    """Exclude two edges from a heterogeneous CSC subgraph without repeats."""
    etype = "A:relation:B"
    csc_formats = {
        etype: gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 2, 3]),
            indices=torch.tensor([2, 1, 0]),
        )
    }
    # The excluded edges are given in original IDs when a mapping is
    # supplied for that side, and in local IDs otherwise.
    original_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    src_to_exclude = (
        torch.tensor([15, 13]) if reverse_row else torch.tensor([2, 0])
    )
    original_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    dst_to_exclude = (
        torch.tensor([10, 12]) if reverse_column else torch.tensor([0, 2])
    )
    original_edge_ids = {etype: torch.tensor([19, 20, 21])}

    subgraph = SampledSubgraphImpl(
        node_pairs=csc_formats,
        original_column_node_ids=original_column_node_ids,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids=original_edge_ids,
    )
    result = subgraph.exclude_edges({etype: (src_to_exclude, dst_to_exclude)})

    # Only the edge stored in the middle column is expected to survive.
    expected_csc_formats = {
        etype: gb.CSCFormatBase(
            indptr=torch.tensor([0, 0, 1, 1]),
            indices=torch.tensor([1]),
        )
    }
    expected_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    expected_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    expected_edge_ids = {etype: torch.tensor([20])}

    _assert_container_equal(result.node_pairs, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
    """Exclude repeated edges from a heterogeneous CSC subgraph."""
    etype = "A:relation:B"
    csc_formats = {
        etype: gb.CSCFormatBase(
            indptr=torch.tensor([0, 2, 4, 5]),
            indices=torch.tensor([2, 2, 1, 1, 0]),
        )
    }
    # The excluded edges are given in original IDs when a mapping is
    # supplied for that side, and in local IDs otherwise.
    original_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    src_to_exclude = (
        torch.tensor([15, 13]) if reverse_row else torch.tensor([2, 0])
    )
    original_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    dst_to_exclude = (
        torch.tensor([10, 12]) if reverse_column else torch.tensor([0, 2])
    )
    original_edge_ids = {etype: torch.tensor([19, 19, 20, 20, 21])}

    subgraph = SampledSubgraphImpl(
        node_pairs=csc_formats,
        original_column_node_ids=original_column_node_ids,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids=original_edge_ids,
    )
    result = subgraph.exclude_edges({etype: (src_to_exclude, dst_to_exclude)})

    # Only the two copies stored in the middle column are expected to stay.
    expected_csc_formats = {
        etype: gb.CSCFormatBase(
            indptr=torch.tensor([0, 0, 2, 2]),
            indices=torch.tensor([1, 1]),
        )
    }
    expected_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    expected_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    expected_edge_ids = {etype: torch.tensor([20, 20])}

    _assert_container_equal(result.node_pairs, expected_csc_formats)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "cpu", F._default_context_str == "cpu",
reason="`to` function needs GPU to test.", reason="`to` function needs GPU to test.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment