Commit bc184b1a authored by rusty1s's avatar rusty1s
Browse files

slightly faster narrow

parent 0fd9cfe2
import copy
from typing import Tuple
import torch
......@@ -85,13 +86,14 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length: Tuple[int, int]) -> SparseTensor:
# This function builds the inverse operation of `cat_diag` and should hence
# only be used on *diagonally stacked* sparse matrices.
# That's the reason why this method is marked as *private*.
rowptr, col, value = src.csr()
rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
row_start = rowptr[0]
row_start = int(rowptr[0])
rowptr = rowptr - row_start
row_length = rowptr[-1]
row_length = int(rowptr[-1])
row = src.storage._row
if row is not None:
......@@ -111,7 +113,7 @@ def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
colptr = src.storage._colptr
if colptr is not None:
colptr = colptr.narrow(0, start[1], length[1] + 1)
colptr = colptr - colptr[0] # i.e. `row_start`
colptr = colptr - int(colptr[0]) # i.e. `row_start`
colcount = src.storage._colcount
if colcount is not None:
......
......@@ -144,12 +144,12 @@ class SparseStorage(object):
self._csc2csr = csc2csr
if not is_sorted:
idx = col.new_zeros(col.numel() + 1)
idx[1:] = sparse_sizes[1] * self.row() + col
idx = self._col.new_zeros(self._col.numel() + 1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
self._row = self.row()[perm]
self._col = col[perm]
self._col = self._col[perm]
if value is not None:
self._value = value[perm]
self._csr2csc = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment