Commit 481fb996 authored by Mario Geiger's avatar Mario Geiger
Browse files

fix

parent 64a8e2ce
......@@ -8,9 +8,10 @@ def _view(src: SparseTensor, n: int, layout: str = 'csr') -> SparseTensor:
row, col, value = src.coo()
sparse_sizes = src.storage.sparse_sizes()
if sparse_sizes[0] * sparse_sizes[1] % n == 0:
if sparse_sizes[0] * sparse_sizes[1] % n != 0:
raise RuntimeError(
f"shape '[-1, {n}]' is invalid for input of size {sparse_sizes[0] * sparse_sizes[1]}")
f"shape '[-1, {n}]' is invalid for input of size "
f"{sparse_sizes[0] * sparse_sizes[1]}")
assert layout == 'csr' or layout == 'csc'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment