Unverified Commit 93a58343 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Remove unnecessary check and synchronization (#6863)

parent a2cb2ecd
......@@ -164,6 +164,14 @@ class CSCFormatBase:
indptr: torch.Tensor = None
indices: torch.Tensor = None
def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
self.indptr = indptr
self.indices = indices
if not indptr.is_cuda:
assert self.indptr[-1] == len(
self.indices
), "The last element of indptr should be the same as the length of indices."
def __repr__(self) -> str:
return _csc_format_base_str(self)
......
......@@ -254,9 +254,6 @@ def unique_and_compact_csc_formats(
for etype, csc_format in csc_formats.items():
if device is None:
device = csc_format.indices.device
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
......@@ -358,9 +355,6 @@ def compact_csc_format(
assert isinstance(
dst_nodes, torch.Tensor
), "Edge type not supported in homogeneous graph."
assert csc_formats.indptr[-1] == len(
csc_formats.indices
), "The last element of indptr should be the same as the length of indices."
assert len(dst_nodes) + 1 == len(
csc_formats.indptr
), "The seed nodes should correspond to indptr."
......@@ -381,9 +375,6 @@ def compact_csc_format(
compacted_csc_formats = {}
original_row_ids = copy.deepcopy(dst_nodes)
for etype, csc_format in csc_formats.items():
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
......
......@@ -202,7 +202,7 @@ class MiniBatch:
v.indices,
torch.arange(
0,
v.indptr[-1],
len(v.indices),
device=v.indptr.device,
dtype=v.indptr.dtype,
),
......@@ -227,7 +227,7 @@ class MiniBatch:
sampled_csc.indices,
torch.arange(
0,
sampled_csc.indptr[-1],
len(sampled_csc.indices),
device=sampled_csc.indptr.device,
dtype=sampled_csc.indptr.dtype,
),
......
......@@ -244,3 +244,11 @@ def test_csc_format_base_representation():
)"""
)
assert str(csc_format_base) == expected_result, print(csc_format_base)
def test_csc_format_base_incorrect_indptr():
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
with pytest.raises(AssertionError):
# The value of last element in indptr is not corresponding to indices.
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
......@@ -350,14 +350,6 @@ def test_unique_and_compact_incorrect_indptr():
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)
def test_compact_csc_format_hetero():
dst_nodes = {
......@@ -449,11 +441,3 @@ def test_compact_incorrect_indptr():
# The number of seeds is not corresponding to indptr.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment