Commit 47b719bb authored by rusty1s's avatar rusty1s
Browse files

narrow fix

parent 63470ce3
...@@ -6,8 +6,8 @@ def narrow(src, dim, start, length): ...@@ -6,8 +6,8 @@ def narrow(src, dim, start, length):
start = src.size(dim) + start if start < 0 else start start = src.size(dim) + start if start < 0 else start
if dim == 0: if dim == 0:
(row, col), value = src.coo() rowptr, col, value = src.csr()
rowptr = src.storage.rowptr # rowptr = src.storage.rowptr
# Maintain `rowcount`... # Maintain `rowcount`...
rowcount = src.storage._rowcount rowcount = src.storage._rowcount
...@@ -19,26 +19,27 @@ def narrow(src, dim, start, length): ...@@ -19,26 +19,27 @@ def narrow(src, dim, start, length):
rowptr = rowptr - row_start rowptr = rowptr - row_start
row_length = rowptr[-1] row_length = rowptr[-1]
row = row.narrow(0, row_start, row_length) - start row = src.storage._row
if row is not None:
row = row.narrow(0, row_start, row_length) - start
col = col.narrow(0, row_start, row_length) col = col.narrow(0, row_start, row_length)
index = torch.stack([row, col], dim=0)
if src.has_value(): if src.has_value():
value = value.narrow(0, row_start, row_length) value = value.narrow(0, row_start, row_length)
sparse_size = torch.Size([length, src.sparse_size(1)]) sparse_size = torch.Size([length, src.sparse_size(1)])
storage = src.storage.__class__( storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
index, value=value, sparse_size=sparse_size,
value, rowcount=rowcount, is_sorted=True)
sparse_size,
rowcount=rowcount,
rowptr=rowptr,
is_sorted=True)
elif dim == 1: elif dim == 1:
# This is faster than accessing `csc()` contrary to the `dim=0` case. # This is faster than accessing `csc()` contrary to the `dim=0` case.
(row, col), value = src.coo() row, col, value = src.coo()
mask = (col >= start) & (col < start + length) mask = (col >= start) & (col < start + length)
row, col = row[mask], col[mask] - start
# Maintain `colcount`... # Maintain `colcount`...
colcount = src.storage._colcount colcount = src.storage._colcount
if colcount is not None: if colcount is not None:
...@@ -50,21 +51,17 @@ def narrow(src, dim, start, length): ...@@ -50,21 +51,17 @@ def narrow(src, dim, start, length):
colptr = colptr.narrow(0, start=start, length=length + 1) colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0] colptr = colptr - colptr[0]
index = torch.stack([row, col - start], dim=0)[:, mask]
if src.has_value(): if src.has_value():
value = value[mask] value = value[mask]
sparse_size = torch.Size([src.sparse_size(0), length]) sparse_size = torch.Size([src.sparse_size(0), length])
storage = src.storage.__class__( storage = src.storage.__class__(row=row, col=col, value=value,
index, sparse_size=sparse_size, colptr=colptr,
value, colcount=colcount, is_sorted=True)
sparse_size,
colcount=colcount,
colptr=colptr,
is_sorted=True)
else: else:
storage = src.storage.apply_value(lambda x: x.narrow( storage = src.storage.apply_value(
dim - 1, start, length)) lambda x: x.narrow(dim - 1, start, length))
return src.from_storage(storage) return src.from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment