Unverified Commit 109aed56 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add check in `unique_and_compact_csc_format`. (#6744)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 0e188b61
......@@ -248,7 +248,13 @@ def unique_and_compact_csc_formats(
# Collect all source and destination nodes for each node type.
indices = defaultdict(list)
for etype, csc_format in csc_formats.items():
src_type, _, _ = etype_str_to_tuple(etype)
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
), "The seed nodes should correspond to indptr."
indices[src_type].append(csc_format.indices)
indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()}
......@@ -364,7 +370,7 @@ def compact_csc_format(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(dst_nodes[dst_type]) + 1 == len(
assert len(dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
), "The seed nodes should correspond to indptr."
offset = original_row_ids.get(src_type, torch.tensor([])).size(0)
......
......@@ -267,7 +267,7 @@ def test_unique_and_compact_csc_formats_hetero():
def test_unique_and_compact_csc_formats_homo():
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 10, 11])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
......@@ -286,6 +286,25 @@ def test_unique_and_compact_csc_formats_homo():
assert torch.equal(unique_nodes, expected_unique_nodes)
def test_unique_and_compact_incorrect_indptr():
seeds = torch.tensor([1, 3, 5, 2, 6, 7])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The number of seeds is not corresponding to indptr.
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)
def test_compact_csc_format_hetero():
dst_nodes = {
"n2": torch.tensor([2, 4, 1, 3]),
......@@ -365,3 +384,22 @@ def test_compact_csc_format_homo():
assert torch.equal(indptr, expected_indptr)
assert torch.equal(indices, expected_indices)
assert torch.equal(original_row_ids, expected_original_row_ids)
def test_compact_incorrect_indptr():
seeds = torch.tensor([1, 3, 5, 2, 6, 7])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The number of seeds is not corresponding to indptr.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment