Commit 9216364c authored by rusty1s

__getitem__ numpy notation

parent 6b9127a0
@@ -16,7 +16,7 @@ def arange_interleave(start, repeat):

 def index_select(src, dim, idx):
-    dim = src.dim() - dim if dim < 0 else dim
+    dim = src.dim() + dim if dim < 0 else dim
     assert idx.dim() == 1
     idx = idx.to(src.device)
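
A quick check of the negative-dimension fix above (the `wrap_dim` helper below is a stand-in for illustration, not part of torch_sparse):

# `wrap_dim` mirrors the corrected expression `src.dim() + dim if dim < 0 else dim`.
def wrap_dim(ndim, dim):
    return ndim + dim if dim < 0 else dim  # new behavior

assert wrap_dim(2, -1) == 1   # last dimension of a 2-dim tensor
assert wrap_dim(2, 0) == 0    # non-negative dims pass through unchanged
# The old expression `ndim - dim` would have produced 2 - (-1) == 3 here,
# an invalid dimension for a 2-dim tensor.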
@@ -38,8 +38,8 @@ def index_select(src, dim, idx):
         sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        rowcount=rowcount, is_sorted=True)
+        storage = src.storage.__class__(
+            index, value, sparse_size, rowcount=rowcount, is_sorted=True)

     elif dim == 1:
         colptr, row, value = src.csc()
@@ -58,13 +58,17 @@ def index_select(src, dim, idx):
         sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        colcount=colcount, csc2csr=csc2csr,
-                                        is_sorted=True)
+        storage = src.storage.__class__(
+            index,
+            value,
+            sparse_size,
+            colcount=colcount,
+            csc2csr=csc2csr,
+            is_sorted=True)
     else:
-        storage = src.storage.apply_value(
-            lambda x: x.index_select(dim - 1, idx))
+        storage = src.storage.apply_value(lambda x: x.index_select(
+            dim - 1, idx))

     return src.from_storage(storage)
@@ -82,7 +86,7 @@ def index_select_nnz(src, idx, layout=None):
     value = value[idx]

     # There is no other information we can maintain...
-    storage = src.storage.__class__(index, value, src.sparse_size(),
-                                    is_sorted=True)
+    storage = src.storage.__class__(
+        index, value, src.sparse_size(), is_sorted=True)

     return src.from_storage(storage)
@@ -4,7 +4,7 @@ from torch_sparse.storage import get_layout

 def masked_select(src, dim, mask):
-    dim = src.dim() - dim if dim < 0 else dim
+    dim = src.dim() + dim if dim < 0 else dim
     assert mask.dim() == 1
     storage = src.storage
@@ -25,8 +25,8 @@ def masked_select(src, dim, mask):
         sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        rowcount=rowcount, is_sorted=True)
+        storage = src.storage.__class__(
+            index, value, sparse_size, rowcount=rowcount, is_sorted=True)

     elif dim == 1:
         csr2csc = src.storage.csr2csc
@@ -48,14 +48,18 @@ def masked_select(src, dim, mask):
         sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        colcount=colcount, csc2csr=csc2csr,
-                                        is_sorted=True)
+        storage = src.storage.__class__(
+            index,
+            value,
+            sparse_size,
+            colcount=colcount,
+            csc2csr=csc2csr,
+            is_sorted=True)
     else:
         idx = mask.nonzero().view(-1)
-        storage = src.storage.apply_value(
-            lambda x: x.index_select(dim - 1, idx))
+        storage = src.storage.apply_value(lambda x: x.index_select(
+            dim - 1, idx))

     return src.from_storage(storage)
@@ -73,7 +77,7 @@ def masked_select_nnz(src, mask, layout=None):
     value = value[mask]

     # There is no other information we can maintain...
-    storage = src.storage.__class__(index, value, src.sparse_size(),
-                                    is_sorted=True)
+    storage = src.storage.__class__(
+        index, value, src.sparse_size(), is_sorted=True)

     return src.from_storage(storage)
@@ -2,7 +2,8 @@ import torch

 def narrow(src, dim, start, length):
-    dim = src.dim() - dim if dim < 0 else dim
+    dim = src.dim() + dim if dim < 0 else dim
+    start = src.size(dim) + start if start < 0 else start

     if dim == 0:
         (row, col), value = src.coo()
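
The added `start` line lets `narrow` accept a negative start offset by wrapping it with the size of the target dimension, matching how negative starts behave on dense tensors. A small dense-tensor sketch of the same arithmetic (illustration only):

import torch

# Same wrapping arithmetic as the added line, applied to a dense tensor.
x = torch.arange(12).view(3, 4)
dim, start, length = 1, -2, 2
start = x.size(dim) + start if start < 0 else start   # -2 -> 2
print(torch.equal(x.narrow(dim, start, length), x[:, 2:]))  # True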
@@ -25,8 +26,12 @@ def narrow(src, dim, start, length):
         value = value.narrow(0, row_start, row_length)

         sparse_size = torch.Size([length, src.sparse_size(1)])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        rowcount=rowcount, rowptr=rowptr,
-                                        is_sorted=True)
+        storage = src.storage.__class__(
+            index,
+            value,
+            sparse_size,
+            rowcount=rowcount,
+            rowptr=rowptr,
+            is_sorted=True)

     elif dim == 1:
@@ -50,12 +55,16 @@ def narrow(src, dim, start, length):
         value = value[mask]

         sparse_size = torch.Size([src.sparse_size(0), length])
-        storage = src.storage.__class__(index, value, sparse_size,
-                                        colcount=colcount, colptr=colptr,
-                                        is_sorted=True)
+        storage = src.storage.__class__(
+            index,
+            value,
+            sparse_size,
+            colcount=colcount,
+            colptr=colptr,
+            is_sorted=True)
     else:
-        storage = src.storage.apply_value(
-            lambda x: x.narrow(dim - 1, start, length))
+        storage = src.storage.apply_value(lambda x: x.narrow(
+            dim - 1, start, length))

     return src.from_storage(storage)
@@ -318,6 +318,48 @@ class SparseTensor(object):
         value = value.detach().cpu().numpy() if self.has_value() else ones
         return scipy.sparse.csc_matrix((value, row, colptr), self.size())

+    # Standard Operators ######################################################
+
+    def __getitem__(self, index):
+        index = list(index) if isinstance(index, tuple) else [index]
+
+        if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
+            raise SyntaxError()
+
+        dim = 0
+        out = self
+        while len(index) > 0:
+            item = index.pop(0)
+            if isinstance(item, int):
+                out = out.select(dim, item)
+                dim += 1
+            elif isinstance(item, slice):
+                if item.step is not None:
+                    raise ValueError('Step parameter not yet supported.')
+                start = 0 if item.start is None else item.start
+                start = self.size(dim) + start if start < 0 else start
+                stop = self.size(dim) if item.stop is None else item.stop
+                stop = self.size(dim) + stop if stop < 0 else stop
+                out = out.narrow(dim, start, max(stop - start, 0))
+                dim += 1
+            elif torch.is_tensor(item):
+                if item.dtype == torch.bool:
+                    out = out.masked_select(dim, item)
+                    dim += 1
+                elif item.dtype == torch.long:
+                    out = out.index_select(dim, item)
+                    dim += 1
+            elif item == Ellipsis:
+                if self.dim() - len(index) < dim:
+                    raise SyntaxError()
+                dim = self.dim() - len(index)
+            else:
+                raise SyntaxError()
+
+        return out
+
     # String Reputation #######################################################

     def __repr__(self):
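
The new `__getitem__` walks the index tuple left to right and dispatches each item to an existing method: an `int` goes to `select`, a step-free `slice` to `narrow`, a `torch.long` tensor to `index_select`, a `torch.bool` tensor to `masked_select`, and `...` jumps ahead so trailing items address the last dimensions. A rough equivalence sketch, assuming `from_dense`/`to_dense` behave as in the surrounding code:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.randn(4, 4))
idx = torch.tensor([0, 2])
mask = torch.tensor([True, False, True, False])

# slice -> narrow, long tensor -> index_select
print(torch.equal(mat[1:3, idx].to_dense(),
                  mat.narrow(0, 1, 2).index_select(1, idx).to_dense()))
# bool tensor -> masked_select
print(torch.equal(mat[mask].to_dense(),
                  mat.masked_select(0, mask).to_dense()))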
@@ -457,18 +499,34 @@ if __name__ == '__main__':
     dataset = Planetoid('/tmp/Cora', 'Cora')
     data = dataset[0].to(device)
-    value = torch.randn((data.num_edges, ), device=device)
-    mat1 = SparseTensor(data.edge_index, value)
-    mat1 = SparseTensor.from_dense(mat1.to_dense())
-    print(mat1)
-    mat = SparseTensor.from_torch_sparse_coo_tensor(
-        mat1.to_torch_sparse_coo_tensor())
-    mat = SparseTensor.from_scipy(mat.to_scipy(layout='csc'))
-    print(mat)
+    value = torch.randn(data.num_edges, 10)
+    mat = SparseTensor(data.edge_index, value)
+
+    index = torch.tensor([0, 1, 2])
+    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+    mask[:3] = True
+
+    print(mat[1].size())
+    print(mat[1, 1].size())
+    print(mat[..., -1].size())
+    print(mat[:10, ..., -1].size())
+    print(mat[:, -1].size())
+    print(mat[1, :, -1].size())
+    print(mat[1:4, 1:4].size())
+    print(mat[index].size())
+    print(mat[index, index].size())
+    print(mat[mask, index].size())
+    # mat[::-1]
+    # mat[::2]
+
+    # mat1 = SparseTensor.from_dense(mat1.to_dense())
+    # print(mat1)
+    # mat = SparseTensor.from_torch_sparse_coo_tensor(
+    #     mat1.to_torch_sparse_coo_tensor())
+    # mat = SparseTensor.from_scipy(mat.to_scipy(layout='csc'))
+    # print(mat)
     # index = torch.tensor([0, 2])
     # mat2 = mat1.index_select(2, index)
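
The two commented-out lines in the demo (`mat[::-1]`, `mat[::2]`) correspond to the unsupported case: any slice carrying a step hits the `raise ValueError('Step parameter not yet supported.')` branch of `__getitem__`. A sketch of the expected failure mode (illustration only, using `from_dense` to build a small matrix):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.randn(4, 4))
try:
    mat[::2]                       # slice with step=2
except ValueError as err:
    print(err)                     # Step parameter not yet supported.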