Commit 592d63d2 authored by rusty1s

repr

parent f87afd09
@@ -65,13 +65,14 @@ def test_jit():
     # adj = Foo(adj.storage.rowptr, adj.storage.col)
     # adj = adj.storage
-    rowptr = torch.tensor([0, 3, 6, 9])
-    col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
+    rowptr = torch.tensor([0, 1, 4, 7])
+    col = torch.tensor([0, 0, 1, 2, 0, 1, 2])
     adj = SparseTensor(rowptr=rowptr, col=col)
-    scipy = adj.to_scipy(layout='csr')
-    mat = SparseTensor.from_scipy(scipy)
-    mat.fill_value_(2.3)
+    # scipy = adj.to_scipy(layout='csr')
+    # mat = SparseTensor.from_scipy(scipy)
+    print()
+    print(adj)
     # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
     # foo = Foo(mat.storage.rowptr, mat.storage.col)
...
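Note on the updated fixture: in CSR form, row i owns the entries col[rowptr[i]:rowptr[i + 1]], so rowptr = [0, 1, 4, 7] means row 0 has one non-zero (column 0) and rows 1 and 2 have three each (columns 0, 1, 2), for 7 non-zeros in a 3 x 3 matrix. Below is a minimal sketch of expanding the pointer into explicit row indices with plain PyTorch; the helper name rowptr_to_row is only for illustration and is not part of torch_sparse.

import torch

def rowptr_to_row(rowptr: torch.Tensor) -> torch.Tensor:
    # Row i owns col[rowptr[i]:rowptr[i + 1]]; expand the pointer into an
    # explicit per-non-zero row index by repeating each row id by its count.
    counts = rowptr[1:] - rowptr[:-1]
    return torch.repeat_interleave(torch.arange(rowptr.numel() - 1), counts)

rowptr = torch.tensor([0, 1, 4, 7])
col = torch.tensor([0, 0, 1, 2, 0, 1, 2])
print(rowptr_to_row(rowptr))  # tensor([0, 1, 1, 1, 2, 2, 2])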
-# from textwrap import indent
+from textwrap import indent
 from typing import Optional, List, Tuple, Union
 import torch
@@ -382,92 +382,29 @@ class SparseTensor(object):
         return torch.sparse_coo_tensor(index, value, self.sizes())
-    # Standard Operators ######################################################
-    # def __add__(self, other):
-    #     return self.add(other)
-    # def __radd__(self, other):
-    #     return self.add(other)
-    # def __iadd__(self, other):
-    #     return self.add_(other)
-    # def __mul__(self, other):
-    #     return self.mul(other)
-    # def __rmul__(self, other):
-    #     return self.mul(other)
-    # def __imul__(self, other):
-    #     return self.mul_(other)
-    # def __matmul__(self, other):
-    #     return matmul(self, other, reduce='sum')
-    # # Standard Operators ######################################################
-    # def __getitem__(self, index):
-    #     index = list(index) if isinstance(index, tuple) else [index]
-    #     # More than one `Ellipsis` is not allowed...
-    #     if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
-    #         raise SyntaxError
-    #     dim = 0
-    #     out = self
-    #     while len(index) > 0:
-    #         item = index.pop(0)
-    #         if isinstance(item, int):
-    #             out = out.select(dim, item)
-    #             dim += 1
-    #         elif isinstance(item, slice):
-    #             if item.step is not None:
-    #                 raise ValueError('Step parameter not yet supported.')
-    #             start = 0 if item.start is None else item.start
-    #             start = self.size(dim) + start if start < 0 else start
-    #             stop = self.size(dim) if item.stop is None else item.stop
-    #             stop = self.size(dim) + stop if stop < 0 else stop
-    #             out = out.narrow(dim, start, max(stop - start, 0))
-    #             dim += 1
-    #         elif torch.is_tensor(item):
-    #             if item.dtype == torch.bool:
-    #                 out = out.masked_select(dim, item)
-    #                 dim += 1
-    #             elif item.dtype == torch.long:
-    #                 out = out.index_select(dim, item)
-    #                 dim += 1
-    #         elif item == Ellipsis:
-    #             if self.dim() - len(index) < dim:
-    #                 raise SyntaxError
-    #             dim = self.dim() - len(index)
-    #         else:
-    #             raise SyntaxError
-    #     return out
-    # def __add__(self, other):
-    #     return self.add(other)
-    # def __radd__(self, other):
-    #     return self.add(other)
-    # def __iadd__(self, other):
-    #     return self.add_(other)
-    # def __mul__(self, other):
-    #     return self.mul(other)
-    # def __rmul__(self, other):
-    #     return self.mul(other)
-    # def __imul__(self, other):
-    #     return self.mul_(other)
-    # def __matmul__(self, other):
-    #     return matmul(self, other, reduce='sum')
-    # # String Reputation #######################################################
-    # def __repr__(self):
-    #     i = ' ' * 6
-    #     row, col, value = self.coo()
-    #     infos = []
-    #     infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
-    #     infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
-    #     if self.has_value():
-    #         infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
-    #     infos += [
-    #         f'size={tuple(self.size())}, '
-    #         f'nnz={self.nnz()}, '
-    #         f'density={100 * self.density():.02f}%'
-    #     ]
-    #     infos = ',\n'.join(infos)
-    #     i = ' ' * (len(self.__class__.__name__) + 1)
-    #     return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
 # Bindings ####################################################################
@@ -531,9 +468,77 @@ def to(self, *args, **kwargs):
     return self
+@torch.jit.ignore
+def __getitem__(self, index):
+    raise NotImplementedError
+    index = list(index) if isinstance(index, tuple) else [index]
+    # More than one `Ellipsis` is not allowed...
+    if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
+        raise SyntaxError
+    dim = 0
+    out = self
+    while len(index) > 0:
+        item = index.pop(0)
+        if isinstance(item, int):
+            out = out.select(dim, item)
+            dim += 1
+        elif isinstance(item, slice):
+            if item.step is not None:
+                raise ValueError('Step parameter not yet supported.')
+            start = 0 if item.start is None else item.start
+            start = self.size(dim) + start if start < 0 else start
+            stop = self.size(dim) if item.stop is None else item.stop
+            stop = self.size(dim) + stop if stop < 0 else stop
+            out = out.narrow(dim, start, max(stop - start, 0))
+            dim += 1
+        elif torch.is_tensor(item):
+            if item.dtype == torch.bool:
+                out = out.masked_select(dim, item)
+                dim += 1
+            elif item.dtype == torch.long:
+                out = out.index_select(dim, item)
+                dim += 1
+        elif item == Ellipsis:
+            if self.dim() - len(index) < dim:
+                raise SyntaxError
+            dim = self.dim() - len(index)
+        else:
+            raise SyntaxError
+    return out
+
+
+@torch.jit.ignore
+def __repr__(self):
+    i = ' ' * 6
+    row, col, value = self.coo()
+    infos = []
+    infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
+    infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
+    if value is not None:
+        infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
+    infos += [
+        f'size={tuple(self.sizes())}, '
+        f'nnz={self.nnz()}, '
+        f'density={100 * self.density():.02f}%'
+    ]
+    infos = ',\n'.join(infos)
+    i = ' ' * (len(self.__class__.__name__) + 1)
+    return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
 SparseTensor.share_memory_ = share_memory_
 SparseTensor.is_shared = is_shared
 SparseTensor.to = to
+SparseTensor.__getitem__ = __getitem__
+SparseTensor.__repr__ = __repr__
 # Scipy Conversions ###########################################################
...
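The headline change (the commit message is "repr") is the new __repr__ binding: it labels the COO components, aligns multi-line tensor reprs with textwrap.indent, and closes with size, nnz, and density. Below is a standalone sketch of the same formatting applied to the test fixture above, using plain tensors rather than torch_sparse itself; the 3 x 3 size and the printed layout are what the fixture and the formatting code imply, not output verified against the library.

from textwrap import indent

import torch

row = torch.tensor([0, 1, 1, 1, 2, 2, 2])
col = torch.tensor([0, 0, 1, 2, 0, 1, 2])

# Same alignment trick as the new __repr__: indent every line of a tensor's
# repr by a fixed width, then strip that width from the first line so it
# starts directly after its 'row=' / 'col=' label.
i = ' ' * 6
infos = [
    f'row={indent(repr(row), i)[len(i):]}',
    f'col={indent(repr(col), i)[len(i):]}',
    f'size={(3, 3)}, nnz={col.numel()}, '
    f'density={100 * col.numel() / 9:.02f}%',
]
body = ',\n'.join(infos)

# Indent continuation lines so they line up under the opening parenthesis.
pad = ' ' * (len('SparseTensor') + 1)
print(f'SparseTensor({indent(body, pad)[len(pad):]})')

The companion __getitem__ binding is also registered, but it raises NotImplementedError before its select/narrow/masked_select/index_select dispatch runs, so indexing remains disabled in this commit.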