Commit 0adb9cab authored by rusty1s's avatar rusty1s
Browse files

fix max check for row.numel() == 0

parent a7063092
...@@ -51,16 +51,22 @@ class SparseStorage(object): ...@@ -51,16 +51,22 @@ class SparseStorage(object):
if sparse_sizes is None: if sparse_sizes is None:
if rowptr is not None: if rowptr is not None:
M = rowptr.numel() - 1 M = rowptr.numel() - 1
elif row is not None: elif row is not None and row.numel() > 0:
M = row.max().item() + 1 M = row.max().item() + 1
elif row is not None and row.numel() == 0:
M = 0
else: else:
raise ValueError raise ValueError
if col.numel() > 0:
N = col.max().item() + 1 N = col.max().item() + 1
else:
N = 0
sparse_sizes = (int(M), int(N)) sparse_sizes = (int(M), int(N))
else: else:
assert len(sparse_sizes) == 2 assert len(sparse_sizes) == 2
if row is not None: if row is not None and row.numel() > 0:
assert row.max().item() < sparse_sizes[0] assert row.max().item() < sparse_sizes[0]
if col.numel() > 0:
assert col.max().item() < sparse_sizes[1] assert col.max().item() < sparse_sizes[1]
if row is not None: if row is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment