Commit 47b719bb authored by rusty1s's avatar rusty1s
Browse files

narrow fix

parent 63470ce3
......@@ -6,8 +6,8 @@ def narrow(src, dim, start, length):
start = src.size(dim) + start if start < 0 else start
if dim == 0:
(row, col), value = src.coo()
rowptr = src.storage.rowptr
rowptr, col, value = src.csr()
# rowptr = src.storage.rowptr
# Maintain `rowcount`...
rowcount = src.storage._rowcount
......@@ -19,26 +19,27 @@ def narrow(src, dim, start, length):
rowptr = rowptr - row_start
row_length = rowptr[-1]
row = row.narrow(0, row_start, row_length) - start
row = src.storage._row
if row is not None:
row = row.narrow(0, row_start, row_length) - start
col = col.narrow(0, row_start, row_length)
index = torch.stack([row, col], dim=0)
if src.has_value():
value = value.narrow(0, row_start, row_length)
sparse_size = torch.Size([length, src.sparse_size(1)])
storage = src.storage.__class__(
index,
value,
sparse_size,
rowcount=rowcount,
rowptr=rowptr,
is_sorted=True)
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
elif dim == 1:
# This is faster than accessing `csc()` contrary to the `dim=0` case.
(row, col), value = src.coo()
row, col, value = src.coo()
mask = (col >= start) & (col < start + length)
row, col = row[mask], col[mask] - start
# Maintain `colcount`...
colcount = src.storage._colcount
if colcount is not None:
......@@ -50,21 +51,17 @@ def narrow(src, dim, start, length):
colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0]
index = torch.stack([row, col - start], dim=0)[:, mask]
if src.has_value():
value = value[mask]
sparse_size = torch.Size([src.sparse_size(0), length])
storage = src.storage.__class__(
index,
value,
sparse_size,
colcount=colcount,
colptr=colptr,
is_sorted=True)
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
colcount=colcount, is_sorted=True)
else:
storage = src.storage.apply_value(lambda x: x.narrow(
dim - 1, start, length))
storage = src.storage.apply_value(
lambda x: x.narrow(dim - 1, start, length))
return src.from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment