Unverified commit 9b69ff84, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Remove coo from `exclude_edges`. (#6848)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 884a378c
......@@ -162,17 +162,7 @@ class SampledSubgraph:
# 1. Convert the node pairs to the original ids if they are compacted.
# 2. Exclude the edges and get the index of the edges to keep.
# 3. Slice the subgraph according to the index.
if isinstance(self.sampled_csc, tuple):
reverse_edges = _to_reverse_ids_node_pairs(
self.sampled_csc,
self.original_row_node_ids,
self.original_column_node_ids,
)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph_node_pairs(self, index))
elif isinstance(self.sampled_csc, CSCFormatBase):
if isinstance(self.sampled_csc, CSCFormatBase):
reverse_edges = _to_reverse_ids(
self.sampled_csc,
self.original_row_node_ids,
......@@ -184,7 +174,6 @@ class SampledSubgraph:
return calling_class(*_slice_subgraph(self, index))
else:
index = {}
is_cscformat = 0
for etype, pair in self.sampled_csc.items():
if etype not in edges:
# No edges need to be excluded.
......@@ -201,28 +190,17 @@ class SampledSubgraph:
if self.original_column_node_ids is None
else self.original_column_node_ids.get(dst_type)
)
if isinstance(pair, CSCFormatBase):
is_cscformat = 1
reverse_edges = _to_reverse_ids(
pair,
original_row_node_ids,
original_column_node_ids,
)
else:
reverse_edges = _to_reverse_ids_node_pairs(
pair,
original_row_node_ids,
original_column_node_ids,
)
reverse_edges = _to_reverse_ids(
pair,
original_row_node_ids,
original_column_node_ids,
)
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
if is_cscformat:
return calling_class(*_slice_subgraph(self, index))
else:
return calling_class(*_slice_subgraph_node_pairs(self, index))
return calling_class(*_slice_subgraph(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection."""
......@@ -241,17 +219,6 @@ class SampledSubgraph:
return self
def _to_reverse_ids_node_pairs(
node_pair, original_row_node_ids, original_column_node_ids
):
u, v = node_pair
if original_row_node_ids is not None:
u = original_row_node_ids[u]
if original_column_node_ids is not None:
v = original_column_node_ids[v]
return (u, v)
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
indptr = node_pair.indptr
indices = node_pair.indices
......@@ -293,34 +260,6 @@ def _exclude_homo_edges(
return torch.nonzero(mask, as_tuple=True)[0]
def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
    """Slice the node pairs and edge ids of ``subgraph`` by ``index``.

    ``index`` is applied to ``subgraph.sampled_csc`` and
    ``subgraph.original_edge_ids``; the original column/row node ids are
    passed through untouched.

    Returns
    -------
    tuple
        ``(sliced sampled_csc, original_column_node_ids,
        original_row_node_ids, sliced original_edge_ids)``.
    """

    def _pick(container, selector):
        # Nothing to slice, or nothing to slice by: hand the container back.
        if container is None or selector is None:
            return container
        if isinstance(container, torch.Tensor):
            return container[selector]
        if isinstance(container, tuple):
            return tuple(_pick(item, selector) for item in container)
        # Remaining case: a dictionary, sliced key by key with the selector
        # stored under the same key.
        assert isinstance(container, dict)
        assert isinstance(selector, dict)
        return {
            key: _pick(value, selector[key])
            for key, value in container.items()
        }

    kept_pairs = _pick(subgraph.sampled_csc, index)
    kept_edge_ids = _pick(subgraph.original_edge_ids, index)
    return (
        kept_pairs,
        subgraph.original_column_node_ids,
        subgraph.original_row_node_ids,
        kept_edge_ids,
    )
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
"""Slice the subgraph according to the index."""
......
......@@ -34,119 +34,6 @@ def _assert_container_equal(lhs, rhs):
_assert_container_equal(value, rhs[key])
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column):
    """Check `exclude_edges` on a homogeneous node-pair subgraph.

    The subgraph holds the edges (0, 1), (2, 4) and (3, 2). The middle edge
    is excluded, expressed in original ids on every side for which an
    original-id tensor is supplied and in the stored ids otherwise.
    """
    # Position 2 holds 11 and position 4 holds 9, so (2, 4) corresponds to
    # (11, 9) once a side is translated through this table.
    id_table = [10, 15, 11, 24, 9]

    node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2]))
    original_row_node_ids = torch.tensor(id_table) if reverse_row else None
    original_column_node_ids = (
        torch.tensor(id_table) if reverse_column else None
    )
    src_to_exclude = torch.tensor([11 if reverse_row else 2])
    dst_to_exclude = torch.tensor([9 if reverse_column else 4])
    # Edge ids are integers: build them with `torch.tensor` rather than the
    # legacy `torch.Tensor` constructor, which would yield float32 values.
    original_edge_ids = torch.tensor([5, 9, 10])

    subgraph = FusedSampledSubgraphImpl(
        node_pairs,
        original_column_node_ids,
        original_row_node_ids,
        original_edge_ids,
    )
    result = subgraph.exclude_edges((src_to_exclude, dst_to_exclude))

    # Only the first and last edges remain; the node-id tensors are expected
    # to come back with the same contents they went in with.
    expected_node_pairs = (torch.tensor([0, 3]), torch.tensor([1, 2]))
    expected_row_node_ids = torch.tensor(id_table) if reverse_row else None
    expected_column_node_ids = (
        torch.tensor(id_table) if reverse_column else None
    )
    expected_edge_ids = torch.tensor([5, 10])

    _assert_container_equal(result.sampled_csc, expected_node_pairs)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
    """Check `exclude_edges` on a heterogeneous node-pair subgraph.

    The single edge type "A:relation:B" holds the edges (0, 2), (1, 1) and
    (2, 0). The first and last are excluded, expressed in original ids on
    every side for which an original-id mapping is supplied and in the
    stored ids otherwise, leaving only (1, 1) with edge id 20.
    """
    etype = "A:relation:B"

    node_pairs = {
        etype: (
            torch.tensor([0, 1, 2]),
            torch.tensor([2, 1, 0]),
        )
    }
    original_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    original_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    src_to_exclude = torch.tensor([15, 13] if reverse_row else [2, 0])
    dst_to_exclude = torch.tensor([10, 12] if reverse_column else [0, 2])
    original_edge_ids = {etype: torch.tensor([19, 20, 21])}

    subgraph = FusedSampledSubgraphImpl(
        sampled_csc=node_pairs,
        original_column_node_ids=original_column_node_ids,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids=original_edge_ids,
    )
    result = subgraph.exclude_edges(
        {etype: (src_to_exclude, dst_to_exclude)}
    )

    expected_node_pairs = {
        etype: (
            torch.tensor([1]),
            torch.tensor([1]),
        )
    }
    expected_row_node_ids = (
        {"A": torch.tensor([13, 14, 15])} if reverse_row else None
    )
    expected_column_node_ids = (
        {"B": torch.tensor([10, 11, 12])} if reverse_column else None
    )
    expected_edge_ids = {etype: torch.tensor([20])}

    _assert_container_equal(result.sampled_csc, expected_node_pairs)
    _assert_container_equal(
        result.original_column_node_ids, expected_column_node_ids
    )
    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
    _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
......@@ -387,10 +274,10 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
)
def test_sampled_subgraph_to_device():
# Initialize data.
node_pairs = {
"A:relation:B": (
torch.tensor([0, 1, 2]),
torch.tensor([2, 1, 0]),
csc_format = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 2]),
)
}
original_row_node_ids = {
......@@ -402,8 +289,8 @@ def test_sampled_subgraph_to_device():
}
dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = FusedSampledSubgraphImpl(
sampled_csc=node_pairs,
subgraph = SampledSubgraphImpl(
sampled_csc=csc_format,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
......@@ -421,8 +308,8 @@ def test_sampled_subgraph_to_device():
# Check.
for key in graph.sampled_csc:
assert graph.sampled_csc[key][0].device.type == "cuda"
assert graph.sampled_csc[key][1].device.type == "cuda"
assert graph.sampled_csc[key].indices.device.type == "cuda"
assert graph.sampled_csc[key].indptr.device.type == "cuda"
for key in graph.original_column_node_ids:
assert graph.original_column_node_ids[key].device.type == "cuda"
for key in graph.original_row_node_ids:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment