"docs/source/vscode:/vscode.git/clone" did not exist on "cc26cd8139c672016b6a578ea8d02138b53eb193"
Commit 0adb9cab authored by rusty1s's avatar rusty1s
Browse files

fix max check for row.numel() == 0

parent a7063092
...@@ -51,17 +51,23 @@ class SparseStorage(object): ...@@ -51,17 +51,23 @@ class SparseStorage(object):
if sparse_sizes is None: if sparse_sizes is None:
if rowptr is not None: if rowptr is not None:
M = rowptr.numel() - 1 M = rowptr.numel() - 1
elif row is not None: elif row is not None and row.numel() > 0:
M = row.max().item() + 1 M = row.max().item() + 1
elif row is not None and row.numel() == 0:
M = 0
else: else:
raise ValueError raise ValueError
N = col.max().item() + 1 if col.numel() > 0:
N = col.max().item() + 1
else:
N = 0
sparse_sizes = (int(M), int(N)) sparse_sizes = (int(M), int(N))
else: else:
assert len(sparse_sizes) == 2 assert len(sparse_sizes) == 2
if row is not None: if row is not None and row.numel() > 0:
assert row.max().item() < sparse_sizes[0] assert row.max().item() < sparse_sizes[0]
assert col.max().item() < sparse_sizes[1] if col.numel() > 0:
assert col.max().item() < sparse_sizes[1]
if row is not None: if row is not None:
assert row.dtype == torch.long assert row.dtype == torch.long
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment