Commit 4a569c27 authored by rusty1s's avatar rusty1s
Browse files

maintain colptr

parent a28accb6
...@@ -19,20 +19,25 @@ def narrow(src, dim, start, length): ...@@ -19,20 +19,25 @@ def narrow(src, dim, start, length):
sparse_size = torch.Size([length, src.sparse_size(1)]) sparse_size = torch.Size([length, src.sparse_size(1)])
storage = src._storage.__class__( storage = src._storage.__class__(
index, value, sparse_size, rowptr, is_sorted=True) index, value, sparse_size, rowptr=rowptr, is_sorted=True)
elif dim == 1: elif dim == 1:
# This is faster than accessing `csc()` in analogy to the `dim=0` case. # This is faster than accessing `csc()` in analogy to the `dim=0` case.
(row, col), value = src.coo() (row, col), value = src.coo()
mask = (col >= start) & (col < start + length) mask = (col >= start) & (col < start + length)
colptr = src._storage._colptr
if colptr is not None:
colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0]
index = torch.stack([row, col - start], dim=0)[:, mask] index = torch.stack([row, col - start], dim=0)[:, mask]
if src.has_value(): if src.has_value():
value = value[mask] value = value[mask]
sparse_size = torch.Size([src.sparse_size(0), length]) sparse_size = torch.Size([src.sparse_size(0), length])
storage = src._storage.__class__( storage = src._storage.__class__(
index, value, sparse_size, is_sorted=True) index, value, sparse_size, colptr=colptr, is_sorted=True)
else: else:
storage = src._storage.apply_value(lambda x: x.narrow( storage = src._storage.apply_value(lambda x: x.narrow(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment