".github/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "03214fd4463fe67f6a515b51e0a9ed056a70e91c"
Commit fcf88282 authored by rusty1s's avatar rusty1s
Browse files

tensor fixes

parent 918b1163
...@@ -37,6 +37,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data, ...@@ -37,6 +37,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x; int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// TODO: Make more efficient.
if (thread_idx < numel) { if (thread_idx < numel) {
int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1]; int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
for (int64_t i = idx; i < next_idx; i++) { for (int64_t i = idx; i < next_idx; i++) {
......
...@@ -99,7 +99,6 @@ class SparseStorage(object): ...@@ -99,7 +99,6 @@ class SparseStorage(object):
if value is not None: if value is not None:
assert value.device == col.device assert value.device == col.device
assert value.size(0) == col.size(0) assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None: if rowcount is not None:
assert rowcount.dtype == torch.long assert rowcount.dtype == torch.long
...@@ -160,7 +159,7 @@ class SparseStorage(object): ...@@ -160,7 +159,7 @@ class SparseStorage(object):
def row(self): def row(self):
if self._row is None: if self._row is None:
func = convert_cuda if self.rowptr.is_cuda else convert_cpu func = convert_cuda if self.rowptr.is_cuda else convert_cpu
self._row = func.ptr2ind(self.rowptr, self.nnz()) self._row = func.ptr2ind(self.rowptr, self.col.numel())
return self._row return self._row
def has_rowptr(self): def has_rowptr(self):
...@@ -184,9 +183,9 @@ class SparseStorage(object): ...@@ -184,9 +183,9 @@ class SparseStorage(object):
def value(self): def value(self):
return self._value return self._value
def set_value_(self, value, dtype=None, layout=None): def set_value_(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float): if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), dtype=dtype, value = torch.full((self.col.numel(), ), dtype=dtype,
device=self.col.device) device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc': elif torch.is_tensor(value) and get_layout(layout) == 'csc':
...@@ -200,9 +199,9 @@ class SparseStorage(object): ...@@ -200,9 +199,9 @@ class SparseStorage(object):
self._value = value self._value = value
return self return self
def set_value(self, value, dtype=None, layout=None): def set_value(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float): if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), dtype=dtype, value = torch.full((self.col.numel(), ), dtype=dtype,
device=self.col.device) device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc': elif torch.is_tensor(value) and get_layout(layout) == 'csc':
...@@ -224,7 +223,7 @@ class SparseStorage(object): ...@@ -224,7 +223,7 @@ class SparseStorage(object):
return self._sparse_size return self._sparse_size
def sparse_resize(self, *sizes): def sparse_resize(self, *sizes):
old_sparse_size, nnz = self.sparse_size, self.nnz() old_sparse_size, nnz = self.sparse_size, self.col.numel()
diff_0 = sizes[0] - old_sparse_size[0] diff_0 = sizes[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr rowcount, rowptr = self._rowcount, self._rowptr
...@@ -258,9 +257,6 @@ class SparseStorage(object): ...@@ -258,9 +257,6 @@ class SparseStorage(object):
colcount=colcount, csr2csc=self._csr2csc, colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
def nnz(self):
return self.col.numel()
def has_rowcount(self): def has_rowcount(self):
return self._rowcount is not None return self._rowcount is not None
......
...@@ -13,11 +13,14 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz ...@@ -13,11 +13,14 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce import torch_sparse.reduce
from torch_sparse.diag import remove_diag from torch_sparse.diag import remove_diag
from torch_sparse.matmul import matmul from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
class SparseTensor(object): class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False): def __init__(self, row=None, rowptr=None, col=None, value=None,
self.storage = SparseStorage(index, value, sparse_size, sparse_size=None, is_sorted=False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
is_sorted=is_sorted) is_sorted=is_sorted)
@classmethod @classmethod
...@@ -33,14 +36,15 @@ class SparseTensor(object): ...@@ -33,14 +36,15 @@ class SparseTensor(object):
else: else:
index = mat.nonzero() index = mat.nonzero()
index = index.t().contiguous() row, col = index.t().contiguous()
value = mat[index[0], index[1]] return SparseTensor(row=row, col=col, value=mat[row, col],
return SparseTensor(index, value, mat.size()[:2], is_sorted=True) sparse_size=mat.size()[:2], is_sorted=True)
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False): def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
return SparseTensor(mat._indices(), mat._values(), row, col = mat._indices()
mat.size()[:2], is_sorted=is_sorted) return SparseTensor(row=row, col=col, value=mat._values(),
sparse_size=mat.size()[:2], is_sorted=is_sorted)
@classmethod @classmethod
def from_scipy(self, mat): def from_scipy(self, mat):
...@@ -48,60 +52,52 @@ class SparseTensor(object): ...@@ -48,60 +52,52 @@ class SparseTensor(object):
if isinstance(mat, scipy.sparse.csc_matrix): if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long) colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr() mat = mat.tocsr() # Pre-sort.
rowptr = torch.from_numpy(mat.indptr).to(torch.long) rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo() mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long) row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long) col = torch.from_numpy(mat.col).to(torch.long)
index = torch.stack([row, col], dim=0)
value = torch.from_numpy(mat.data) value = torch.from_numpy(mat.data)
size = mat.shape sparse_size = mat.shape[:2]
storage = SparseStorage(index, value, size, rowptr=rowptr, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
colptr=colptr, is_sorted=True) sparse_size=sparse_size, colptr=colptr,
is_sorted=True)
return SparseTensor.from_storage(storage) return SparseTensor.from_storage(storage)
@classmethod @classmethod
def eye(self, M, N=None, device=None, dtype=None, no_value=False, def eye(self, M, N=None, device=None, dtype=None, has_value=True,
fill_cache=False): fill_cache=False):
N = M if N is None else N N = M if N is None else N
index = torch.empty((2, min(M, N)), dtype=torch.long, device=device) row = torch.arange(min(M, N), device=device)
torch.arange(index.size(1), out=index[0]) rowptr = torch.arange(M + 1, device=device)
torch.arange(index.size(1), out=index[1]) if M > N:
rowptr[row.size(0) + 1:] = row.size(0)
col = row
value = None value = None
if not no_value: if has_value:
value = torch.ones(index.size(1), dtype=dtype, device=device) value = torch.ones(row.size(0), dtype=dtype, device=device)
rowcount = rowptr = colcount = colptr = csr2csc = csc2csr = None rowcount = colptr = colcount = csr2csc = csc2csr = None
if fill_cache: if fill_cache:
rowcount = index.new_ones(M) rowcount = row.new_ones(M)
rowptr = torch.arange(M + 1, device=device)
if M > N: if M > N:
rowcount[index.size(1):] = 0 rowcount[row.size(0):] = 0
rowptr[index.size(1) + 1:] = index.size(1)
colcount = index.new_ones(N)
colptr = torch.arange(N + 1, device=device) colptr = torch.arange(N + 1, device=device)
colcount = col.new_ones(N)
if N > M: if N > M:
colcount[index.size(1):] = 0 colptr[col.size(0) + 1:] = col.size(0)
colptr[index.size(1) + 1:] = index.size(1) colcount[col.size(0):] = 0
csr2csc = torch.arange(index.size(1), device=device) csr2csc = csc2csr = row
csc2csr = torch.arange(index.size(1), device=device)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
storage = SparseStorage( sparse_size=torch.Size([M, N]),
index, rowcount=rowcount, colptr=colptr,
value, colcount=colcount, csr2csc=csr2csc,
torch.Size([M, N]), csc2csr=csc2csr, is_sorted=True)
rowcount=rowcount,
rowptr=rowptr,
colcount=colcount,
colptr=colptr,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
)
return SparseTensor.from_storage(storage) return SparseTensor.from_storage(storage)
def __copy__(self): def __copy__(self):
...@@ -118,7 +114,7 @@ class SparseTensor(object): ...@@ -118,7 +114,7 @@ class SparseTensor(object):
# Formats ################################################################# # Formats #################################################################
def coo(self): def coo(self):
return self.storage.index, self.storage.value return self.storage.row, self.storage.col, self.storage.value
def csr(self): def csr(self):
return self.storage.rowptr, self.storage.col, self.storage.value return self.storage.rowptr, self.storage.col, self.storage.value
...@@ -133,15 +129,16 @@ class SparseTensor(object): ...@@ -133,15 +129,16 @@ class SparseTensor(object):
def has_value(self): def has_value(self):
return self.storage.has_value() return self.storage.has_value()
def set_value_(self, value, layout=None): def set_value_(self, value, layout=None, dtype=None):
self.storage.set_value_(value, layout) self.storage.set_value_(value, layout, dtype)
return self return self
def set_value(self, value, layout=None): def set_value(self, value, layout=None, dtype=None):
return self.from_storage(self.storage.set_value(value, layout)) return self.from_storage(self.storage.set_value(value, layout, dtype))
def sparse_size(self, dim=None): def sparse_size(self, dim=None):
return self.storage.sparse_size(dim) sparse_size = self.storage.sparse_size
return sparse_size if dim is None else sparse_size[dim]
def sparse_resize(self, *sizes): def sparse_resize(self, *sizes):
return self.from_storage(self.storage.sparse_resize(*sizes)) return self.from_storage(self.storage.sparse_resize(*sizes))
...@@ -165,20 +162,20 @@ class SparseTensor(object): ...@@ -165,20 +162,20 @@ class SparseTensor(object):
# Utility functions ####################################################### # Utility functions #######################################################
def dim(self):
return len(self.size())
def size(self, dim=None): def size(self, dim=None):
size = self.sparse_size() size = self.sparse_size()
size += self.storage.value.size()[1:] if self.has_value() else () size += self.storage.value.size()[1:] if self.has_value() else ()
return size if dim is None else size[dim] return size if dim is None else size[dim]
def dim(self):
return len(self.size())
@property @property
def shape(self): def shape(self):
return self.size() return self.size()
def nnz(self): def nnz(self):
return self.storage.nnz() return self.storage.col.numel()
def density(self): def density(self):
return self.nnz() / (self.sparse_size(0) * self.sparse_size(1)) return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
...@@ -202,11 +199,16 @@ class SparseTensor(object): ...@@ -202,11 +199,16 @@ class SparseTensor(object):
if not self.is_quadratic: if not self.is_quadratic:
return False return False
rowptr, col, val1 = self.csr() rowptr, col, value1 = self.csr()
colptr, row, val2 = self.csc() colptr, row, value2 = self.csc()
index_sym = (rowptr == colptr).all() and (col == row).all()
value_sym = (val1 == val2).all().item() if self.has_value() else True if (rowptr != colptr).any() or (col != row).any():
return index_sym.item() and value_sym return False
if not self.has_value():
return True
return (value1 == value2).all().item()
def detach_(self): def detach_(self):
self.storage.apply_(lambda x: x.detach_()) self.storage.apply_(lambda x: x.detach_())
...@@ -219,9 +221,13 @@ class SparseTensor(object): ...@@ -219,9 +221,13 @@ class SparseTensor(object):
def requires_grad(self): def requires_grad(self):
return self.storage.value.requires_grad if self.has_value() else False return self.storage.value.requires_grad if self.has_value() else False
def requires_grad_(self, requires_grad=True): def requires_grad_(self, requires_grad=True, dtype=None):
if requires_grad and not self.has_value():
self.storage.set_value_(1, dtype=dtype)
if self.has_value(): if self.has_value():
self.storage.value.requires_grad_(requires_grad) self.storage.value.requires_grad_(requires_grad)
return self return self
def pin_memory(self): def pin_memory(self):
...@@ -239,7 +245,7 @@ class SparseTensor(object): ...@@ -239,7 +245,7 @@ class SparseTensor(object):
@property @property
def device(self): def device(self):
return self.storage.index.device return self.storage.col.device
def cpu(self): def cpu(self):
return self.from_storage(self.storage.apply(lambda x: x.cpu())) return self.from_storage(self.storage.apply(lambda x: x.cpu()))
...@@ -251,7 +257,7 @@ class SparseTensor(object): ...@@ -251,7 +257,7 @@ class SparseTensor(object):
@property @property
def is_cuda(self): def is_cuda(self):
return self.storage.index.is_cuda return self.storage.col.is_cuda
@property @property
def dtype(self): def dtype(self):
...@@ -296,7 +302,7 @@ class SparseTensor(object): ...@@ -296,7 +302,7 @@ class SparseTensor(object):
if len(args) > 0 or len(kwargs) > 0: if len(args) > 0 or len(kwargs) > 0:
storage = storage.apply_value(lambda x: x.type(*args, **kwargs)) storage = storage.apply_value(lambda x: x.type(*args, **kwargs))
if storage == self.storage: # Nothing changed... if storage == self.storage: # Nothing has been changed...
return self return self
else: else:
return self.from_storage(storage) return self.from_storage(storage)
...@@ -335,19 +341,21 @@ class SparseTensor(object): ...@@ -335,19 +341,21 @@ class SparseTensor(object):
def to_dense(self, dtype=None): def to_dense(self, dtype=None):
dtype = dtype or self.dtype dtype = dtype or self.dtype
(row, col), value = self.coo() row, col, value = self.coo()
mat = torch.zeros(self.size(), dtype=dtype, device=self.device) mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
mat[row, col] = value if self.has_value() else 1 mat[row, col] = value if self.has_value() else 1
return mat return mat
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False): def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
index, value = self.coo() row, col, value = self.coo()
return torch.sparse_coo_tensor( index = torch.stack([row, col], dim=0)
index, value if self.has_value() else torch.ones( if value is None:
self.nnz(), dtype=dtype, device=self.device), self.size(), value = torch.ones(self.nnz(), dtype=dtype, device=self.device)
device=self.device, requires_grad=requires_grad) return torch.sparse_coo_tensor(index, value, self.size(),
device=self.device,
requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout=None): def to_scipy(self, dtype=None, layout="csr"):
assert self.dim() == 2 assert self.dim() == 2
layout = get_layout(layout) layout = get_layout(layout)
...@@ -355,7 +363,7 @@ class SparseTensor(object): ...@@ -355,7 +363,7 @@ class SparseTensor(object):
ones = torch.ones(self.nnz(), dtype=dtype).numpy() ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo': if layout == 'coo':
(row, col), value = self.coo() row, col, value = self.coo()
row = row.detach().cpu().numpy() row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy() col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones value = value.detach().cpu().numpy() if self.has_value() else ones
...@@ -379,7 +387,7 @@ class SparseTensor(object): ...@@ -379,7 +387,7 @@ class SparseTensor(object):
index = list(index) if isinstance(index, tuple) else [index] index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed... # More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1: if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
raise SyntaxError() raise SyntaxError
dim = 0 dim = 0
out = self out = self
...@@ -416,6 +424,15 @@ class SparseTensor(object): ...@@ -416,6 +424,15 @@ class SparseTensor(object):
return out return out
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
def __iadd__(self, other):
return self.add_(other)
def __matmul__(a, b): def __matmul__(a, b):
return matmul(a, b, reduce='sum') return matmul(a, b, reduce='sum')
...@@ -423,8 +440,10 @@ class SparseTensor(object): ...@@ -423,8 +440,10 @@ class SparseTensor(object):
def __repr__(self): def __repr__(self):
i = ' ' * 6 i = ' ' * 6
index, value = self.coo() row, col, value = self.coo()
infos = [f'index={indent(index.__repr__(), i)[len(i):]}'] infos = []
infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
if self.has_value(): if self.has_value():
infos += [f'value={indent(value.__repr__(), i)[len(i):]}'] infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
...@@ -456,8 +475,10 @@ SparseTensor.min = torch_sparse.reduce.min ...@@ -456,8 +475,10 @@ SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag SparseTensor.remove_diag = remove_diag
SparseTensor.matmul = matmul SparseTensor.matmul = matmul
# SparseTensor.add = add SparseTensor.add = add
# SparseTensor.add_nnz = add_nnz SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
SparseTensor.add_nnz_ = add_nnz_
# def __add__(self, other): # def __add__(self, other):
# return self.add(other) # return self.add(other)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment