Commit e7f4ef9f authored by rusty1s's avatar rusty1s
Browse files

fix index select

parent 47b719bb
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
# NOTE(review): this file is a diff paste — indentation was lost in
# extraction, so the statements below sit flush-left although they belong
# inside the test body. The `dtype`/`device` parameters are injected by the
# parametrize decorator but are not used by any visible line; presumably the
# test predates device/dtype coverage — verify against the repository.
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
# Sparse 3x3 pattern with nonzeros at (0,0), (0,1), (1,1), (1,2), (2,1).
row = torch.tensor([0, 0, 1, 1, 2])
col = torch.tensor([0, 1, 1, 2, 1])
mat = SparseTensor(row=row, col=col)
# Debug output only — the test has no assertions; it exercises
# index_select and prints the dense form for manual inspection.
print()
print(mat.to_dense())
# NOTE(review): `pass` between prints looks like a leftover old-version
# line from the diff — confirm against the commit history.
pass
# Select rows 0 and 2 along dim 0; result should be a 2-row matrix.
mat = mat.index_select(0, torch.tensor([0, 2]))
print(mat.to_dense())
import torch
from torch_scatter import gather_csr
from torch_sparse.storage import get_layout
......@@ -9,56 +10,58 @@ def index_select(src, dim, idx):
# NOTE(review): this span is a unified-diff paste of `index_select(src, dim,
# idx)` (the def line is in the hunk header above, outside this view). It is
# NOT runnable: indentation was lost, and several statements appear twice —
# the pre-commit ("old") form immediately followed by the post-commit ("new")
# form. Comments below are review annotations only; verify every old/new
# pairing against the actual repository history.
assert idx.dim() == 1
if dim == 0:
# Old form fetched COO and read rowptr from storage separately; the new
# form gets rowptr, col, value in a single csr() call — presumably the
# two consecutive assignments are the old/new diff pair. TODO confirm.
(row, col), value = src.coo()
old_rowptr, col, value = src.csr()
rowcount = src.storage.rowcount
old_rowptr = src.storage.rowptr
# Restrict per-row nonzero counts to the selected rows.
rowcount = rowcount[idx]
tmp = torch.arange(rowcount.size(0), device=rowcount.device)
row = tmp.repeat_interleave(rowcount)
# Creates an "arange interleave" tensor of col indices.
rowptr = torch.cat([row.new_zeros(1), rowcount.cumsum(0)], dim=0)
# New rowptr construction: zero-initialised then cumulative-summed into
# the tail slice, replacing the cat-based variant above.
rowptr = col.new_zeros(idx.size(0) + 1)
torch.cumsum(rowcount, dim=0, out=rowptr[1:])
row = torch.arange(idx.size(0),
device=col.device).repeat_interleave(rowcount)
# `perm` maps each output nonzero back to its position in the source CSR
# col/value arrays. The gather_csr form replaces the fancy-indexed
# subtraction form directly above it.
perm = torch.arange(row.size(0), device=row.device)
perm += (old_rowptr[idx] - rowptr[:-1])[row]
perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col = col[perm]
index = torch.stack([row, col], dim=0)
if src.has_value():
value = value[perm]
# Output size: selected-row count x original column count. Note
# rowcount.size(0) == idx.size(0) after the restriction above, so the
# two lines are presumably equivalent — readability change only.
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
sparse_size = torch.Size([idx.size(0), src.sparse_size(1)])
# Storage constructor changed from a stacked `index` positional form to
# explicit row/rowptr/col keyword arguments.
storage = src.storage.__class__(index, value, sparse_size,
rowcount=rowcount, rowptr=rowptr,
is_sorted=True)
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
elif dim == 1:
# Column selection: mirrors the dim == 0 branch but on the CSC layout.
old_colptr, row, value = src.csc()
colcount = src.storage.colcount
colcount = colcount[idx]
tmp = torch.arange(colcount.size(0), device=row.device)
col = tmp.repeat_interleave(colcount)
col = torch.arange(idx.size(0),
device=row.device).repeat_interleave(colcount)
colptr = row.new_zeros(idx.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
# Creates an "arange interleave" tensor of row indices.
colptr = torch.cat([col.new_zeros(1), colcount.cumsum(0)], dim=0)
perm = torch.arange(col.size(0), device=col.device)
perm += (old_colptr[idx] - colptr[:-1])[col]
perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row = row[perm]
# csc2csr re-sorts CSC-ordered entries into CSR (row-major) order via a
# row-major key. colcount.size(0) == idx.size(0) here, so the multiplier
# change is presumably cosmetic — TODO confirm.
csc2csr = (colcount.size(0) * row + col).argsort()
index = torch.stack([row, col], dim=0)[:, csc2csr]
csc2csr = (idx.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr]
if src.has_value():
value = value[perm][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
sparse_size = torch.Size([src.sparse_size(0), idx.size(0)])
storage = src.storage.__class__(index, value, sparse_size,
colcount=colcount, colptr=colptr,
csc2csr=csc2csr, is_sorted=True)
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
colcount=colcount, csc2csr=csc2csr,
is_sorted=True)
else:
# Fallthrough branch continues past the visible span (truncated hunk).
storage = src.storage.apply_value(
......@@ -73,14 +76,15 @@ def index_select_nnz(src, idx, layout=None):
# NOTE(review): diff paste of `index_select_nnz(src, idx, layout=None)` (def
# line is in the hunk header, outside this view). Old and new forms of the
# same statements appear consecutively; indentation lost in extraction.
if get_layout(layout) == 'csc':
# For CSC-layout indices, remap idx through csc2csr so it addresses the
# CSR-ordered nonzeros that coo() returns.
idx = idx[src.storage.csc2csr]
# Old coo() returned a stacked 2xnnz `index`; the new one returns
# separate row/col tensors — the next three lines are the old/new pair.
index, value = src.coo()
row, col, value = src.coo()
row, col = row[idx], col[idx]
index = index[:, idx]
if src.has_value():
value = value[idx]
# There is no other information we can maintain...
storage = src.storage.__class__(index, value, src.sparse_size(),
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=src.sparse_size(),
is_sorted=True)
return src.from_storage(storage)
......@@ -80,7 +80,7 @@ class SparseStorage(object):
# NOTE(review): diff paste from inside SparseStorage (per the hunk header);
# old and new forms of the M computation appear consecutively below.
assert col.dim() == 1
if sparse_size is None:
# Defect in the old line: it tested `rowptr is None` yet called
# rowptr.numel() in exactly that branch — the condition is inverted and
# would raise AttributeError whenever rowptr was absent. The new line
# keys on `row is None` instead: derive M from rowptr when row is
# missing, otherwise from row.max() + 1.
M = rowptr.numel() - 1 if rowptr is None else row.max().item() + 1
M = rowptr.numel() - 1 if row is None else row.max().item() + 1
# N is inferred from the largest column index present.
N = col.max().item() + 1
sparse_size = torch.Size([M, N])
......
......@@ -355,7 +355,7 @@ class SparseTensor(object):
# NOTE(review): diff paste from SparseTensor (hunk header references ~L355 of
# the original file). The first two lines are the tail of an unrelated call;
# the two `def to_scipy` lines are the old/new signatures: parameter order
# changed and `layout` now defaults to None, resolved by get_layout(). This
# reorders positional arguments — callers passing dtype positionally would
# break; presumably call sites were updated in the same commit. TODO confirm.
device=self.device,
requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout="csr"):
def to_scipy(self, layout=None, dtype=None):
assert self.dim() == 2
layout = get_layout(layout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment