Commit e7f4ef9f authored by rusty1s's avatar rusty1s
Browse files

fix index select

parent 47b719bb
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
row = torch.tensor([0, 0, 1, 1, 2])
col = torch.tensor([0, 1, 1, 2, 1])
mat = SparseTensor(row=row, col=col)
print()
print(mat.to_dense())
pass
mat = mat.index_select(0, torch.tensor([0, 2]))
print(mat.to_dense())
import torch import torch
from torch_scatter import gather_csr
from torch_sparse.storage import get_layout from torch_sparse.storage import get_layout
...@@ -9,56 +10,58 @@ def index_select(src, dim, idx): ...@@ -9,56 +10,58 @@ def index_select(src, dim, idx):
assert idx.dim() == 1 assert idx.dim() == 1
if dim == 0: if dim == 0:
(row, col), value = src.coo() old_rowptr, col, value = src.csr()
rowcount = src.storage.rowcount rowcount = src.storage.rowcount
old_rowptr = src.storage.rowptr
rowcount = rowcount[idx] rowcount = rowcount[idx]
tmp = torch.arange(rowcount.size(0), device=rowcount.device)
row = tmp.repeat_interleave(rowcount)
# Creates an "arange interleave" tensor of col indices. rowptr = col.new_zeros(idx.size(0) + 1)
rowptr = torch.cat([row.new_zeros(1), rowcount.cumsum(0)], dim=0) torch.cumsum(rowcount, dim=0, out=rowptr[1:])
row = torch.arange(idx.size(0),
device=col.device).repeat_interleave(rowcount)
perm = torch.arange(row.size(0), device=row.device) perm = torch.arange(row.size(0), device=row.device)
perm += (old_rowptr[idx] - rowptr[:-1])[row] perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col = col[perm] col = col[perm]
index = torch.stack([row, col], dim=0)
if src.has_value(): if src.has_value():
value = value[perm] value = value[perm]
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)]) sparse_size = torch.Size([idx.size(0), src.sparse_size(1)])
storage = src.storage.__class__(index, value, sparse_size, storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
rowcount=rowcount, rowptr=rowptr, value=value, sparse_size=sparse_size,
is_sorted=True) rowcount=rowcount, is_sorted=True)
elif dim == 1: elif dim == 1:
old_colptr, row, value = src.csc() old_colptr, row, value = src.csc()
colcount = src.storage.colcount colcount = src.storage.colcount
colcount = colcount[idx] colcount = colcount[idx]
tmp = torch.arange(colcount.size(0), device=row.device) col = torch.arange(idx.size(0),
col = tmp.repeat_interleave(colcount) device=row.device).repeat_interleave(colcount)
colptr = row.new_zeros(idx.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
# Creates an "arange interleave" tensor of row indices.
colptr = torch.cat([col.new_zeros(1), colcount.cumsum(0)], dim=0)
perm = torch.arange(col.size(0), device=col.device) perm = torch.arange(col.size(0), device=col.device)
perm += (old_colptr[idx] - colptr[:-1])[col] perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row = row[perm] row = row[perm]
csc2csr = (colcount.size(0) * row + col).argsort() csc2csr = (idx.size(0) * row + col).argsort()
index = torch.stack([row, col], dim=0)[:, csc2csr] row, col = row[csc2csr], col[csc2csr]
if src.has_value(): if src.has_value():
value = value[perm][csc2csr] value = value[perm][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)]) sparse_size = torch.Size([src.sparse_size(0), idx.size(0)])
storage = src.storage.__class__(index, value, sparse_size, storage = src.storage.__class__(row=row, col=col, value=value,
colcount=colcount, colptr=colptr, sparse_size=sparse_size, colptr=colptr,
csc2csr=csc2csr, is_sorted=True) colcount=colcount, csc2csr=csc2csr,
is_sorted=True)
else: else:
storage = src.storage.apply_value( storage = src.storage.apply_value(
...@@ -73,14 +76,15 @@ def index_select_nnz(src, idx, layout=None): ...@@ -73,14 +76,15 @@ def index_select_nnz(src, idx, layout=None):
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
idx = idx[src.storage.csc2csr] idx = idx[src.storage.csc2csr]
index, value = src.coo() row, col, value = src.coo()
row, col = row[idx], col[idx]
index = index[:, idx]
if src.has_value(): if src.has_value():
value = value[idx] value = value[idx]
# There is no other information we can maintain... # There is no other information we can maintain...
storage = src.storage.__class__(index, value, src.sparse_size(), storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=src.sparse_size(),
is_sorted=True) is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
...@@ -80,7 +80,7 @@ class SparseStorage(object): ...@@ -80,7 +80,7 @@ class SparseStorage(object):
assert col.dim() == 1 assert col.dim() == 1
if sparse_size is None: if sparse_size is None:
M = rowptr.numel() - 1 if rowptr is None else row.max().item() + 1 M = rowptr.numel() - 1 if row is None else row.max().item() + 1
N = col.max().item() + 1 N = col.max().item() + 1
sparse_size = torch.Size([M, N]) sparse_size = torch.Size([M, N])
......
...@@ -355,7 +355,7 @@ class SparseTensor(object): ...@@ -355,7 +355,7 @@ class SparseTensor(object):
device=self.device, device=self.device,
requires_grad=requires_grad) requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout="csr"): def to_scipy(self, layout=None, dtype=None):
assert self.dim() == 2 assert self.dim() == 2
layout = get_layout(layout) layout = get_layout(layout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment