Unverified commit 109aed56, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Add check in `unique_and_compact_csc_formats`. (#6744)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 0e188b61
...@@ -248,7 +248,13 @@ def unique_and_compact_csc_formats( ...@@ -248,7 +248,13 @@ def unique_and_compact_csc_formats(
# Collect all source and destination nodes for each node type. # Collect all source and destination nodes for each node type.
indices = defaultdict(list) indices = defaultdict(list)
for etype, csc_format in csc_formats.items(): for etype, csc_format in csc_formats.items():
src_type, _, _ = etype_str_to_tuple(etype) assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
), "The seed nodes should correspond to indptr."
indices[src_type].append(csc_format.indices) indices[src_type].append(csc_format.indices)
indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()} indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()}
...@@ -364,7 +370,7 @@ def compact_csc_format( ...@@ -364,7 +370,7 @@ def compact_csc_format(
csc_format.indices csc_format.indices
), "The last element of indptr should be the same as the length of indices." ), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(dst_nodes[dst_type]) + 1 == len( assert len(dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr csc_format.indptr
), "The seed nodes should correspond to indptr." ), "The seed nodes should correspond to indptr."
offset = original_row_ids.get(src_type, torch.tensor([])).size(0) offset = original_row_ids.get(src_type, torch.tensor([])).size(0)
......
...@@ -267,7 +267,7 @@ def test_unique_and_compact_csc_formats_hetero(): ...@@ -267,7 +267,7 @@ def test_unique_and_compact_csc_formats_hetero():
def test_unique_and_compact_csc_formats_homo(): def test_unique_and_compact_csc_formats_homo():
seeds = torch.tensor([1, 3, 5, 2, 6]) seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 10, 11]) indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6]) indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices) csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
...@@ -286,6 +286,25 @@ def test_unique_and_compact_csc_formats_homo(): ...@@ -286,6 +286,25 @@ def test_unique_and_compact_csc_formats_homo():
assert torch.equal(unique_nodes, expected_unique_nodes) assert torch.equal(unique_nodes, expected_unique_nodes)
def test_unique_and_compact_incorrect_indptr():
    """Check that `unique_and_compact_csc_formats` rejects malformed input.

    Every case pairs the same 6-entry ``indptr`` (which describes 5 seed
    nodes and ends at 11) with data that breaks one of the consistency
    requirements, and expects an ``AssertionError``.
    """
    offsets = [0, 2, 4, 6, 7, 11]
    neighbors = [2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6]
    malformed_cases = (
        # Six seeds, but indptr only has room for five.
        ([1, 3, 5, 2, 6, 7], neighbors),
        # Ten indices, but the last indptr entry claims eleven.
        ([1, 3, 5, 2, 6], neighbors[:-1]),
    )
    for seed_ids, neighbor_ids in malformed_cases:
        # Build the CSC structure outside of `pytest.raises` so that only
        # the call under test is allowed to raise.
        graph = gb.CSCFormatBase(
            indptr=torch.tensor(offsets),
            indices=torch.tensor(neighbor_ids),
        )
        with pytest.raises(AssertionError):
            gb.unique_and_compact_csc_formats(graph, torch.tensor(seed_ids))
def test_compact_csc_format_hetero(): def test_compact_csc_format_hetero():
dst_nodes = { dst_nodes = {
"n2": torch.tensor([2, 4, 1, 3]), "n2": torch.tensor([2, 4, 1, 3]),
...@@ -365,3 +384,22 @@ def test_compact_csc_format_homo(): ...@@ -365,3 +384,22 @@ def test_compact_csc_format_homo():
assert torch.equal(indptr, expected_indptr) assert torch.equal(indptr, expected_indptr)
assert torch.equal(indices, expected_indices) assert torch.equal(indices, expected_indices)
assert torch.equal(original_row_ids, expected_original_row_ids) assert torch.equal(original_row_ids, expected_original_row_ids)
def test_compact_incorrect_indptr():
    """Check that `compact_csc_format` rejects malformed input.

    Every case pairs the same 6-entry ``indptr`` (which describes 5 seed
    nodes and ends at 11) with data that breaks one of the consistency
    requirements, and expects an ``AssertionError``.
    """
    offsets = [0, 2, 4, 6, 7, 11]
    neighbors = [2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6]
    malformed_cases = (
        # Six seeds, but indptr only has room for five.
        ([1, 3, 5, 2, 6, 7], neighbors),
        # Ten indices, but the last indptr entry claims eleven.
        ([1, 3, 5, 2, 6], neighbors[:-1]),
    )
    for seed_ids, neighbor_ids in malformed_cases:
        # Build the CSC structure outside of `pytest.raises` so that only
        # the call under test is allowed to raise.
        graph = gb.CSCFormatBase(
            indptr=torch.tensor(offsets),
            indices=torch.tensor(neighbor_ids),
        )
        with pytest.raises(AssertionError):
            gb.compact_csc_format(graph, torch.tensor(seed_ids))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment