Commit f59fe649 authored by rusty1s

beginning of torch script support

parent c4484dbb
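TorchScript compiles a statically typed subset of Python, so classes must declare their attribute types up front before `torch.jit.script` can compile them; that is the pattern `Foo` and `SparseStorage` follow below. A minimal self-contained sketch of such a script class (the name `CSRMeta` is illustrative, not part of this commit):

import torch
from typing import Optional

@torch.jit.script
class CSRMeta(object):
    # TorchScript requires class-level type annotations for attributes.
    rowptr: torch.Tensor
    col: torch.Tensor
    value: Optional[torch.Tensor]

    def __init__(self, rowptr: torch.Tensor, col: torch.Tensor,
                 value: Optional[torch.Tensor] = None):
        self.rowptr = rowptr
        self.col = col
        self.value = value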
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.storage import SparseStorage
from typing import Dict, Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@torch.jit.script
class Foo:
rowptr: torch.Tensor
col: torch.Tensor
def __init__(self, rowptr: torch.Tensor, col: torch.Tensor):
self.rowptr = rowptr
self.col = col
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(2, 4)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor:
out, _ = torch.ops.torch_sparse_cpu.spmm(adj.rowptr(), adj.col(), None,
x, 'sum')
return out
# ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
# # ind = ptr2ind(ptr, E)
# x_j = x[ind]
# out = self.linear(x_j)
# return out
def test_jit():
my_cell = MyCell()
# x = torch.rand(3, 2)
# ptr = torch.tensor([0, 2, 4, 6])
# out = my_cell(x, ptr)
# print()
# print(out)
# traced_cell = torch.jit.trace(my_cell, (x, ptr))
# print(traced_cell)
# out = traced_cell(x, ptr)
# print(out)
x = torch.randn(3, 2)
# adj = torch.randn(3, 3)
# adj = SparseTensor.from_dense(adj)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
adj = SparseStorage(rowptr=rowptr, col=col)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# adj = MyTensor(mat.storage.rowptr, mat.storage.col)
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
out = traced_cell(x, adj)
print(out)
# # print(traced_cell.code)
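The `torch.ops.torch_sparse_cpu.spmm` call in `MyCell.forward` resolves to a custom C++ operator registered via `torch.ops.load_library` (see the utils section below); ops registered this way are callable both eagerly and from TorchScript, unlike plain Python extension functions. A minimal eager-mode sketch, assuming the shared-library path used later in this commit:

import torch

torch.ops.load_library('torch_sparse/spmm_cpu.so')  # registers torch_sparse_cpu.*

rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
x = torch.randn(3, 2)
# Same signature as in MyCell.forward: value=None means unweighted spmm.
out, _ = torch.ops.torch_sparse_cpu.spmm(rowptr, col, None, x, 'sum')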
import torch
import torch_scatter
from .unique import unique
# from .unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0):
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise NotImplementedError
row, col = index
...
import torch
from torch_sparse.utils import ext
def remove_diag(src, k=0):
row, col, value = src.coo()
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
row, col, value = src.coo()
mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1),
k)
if row.is_cuda:
mask = torch.ops.torch_sparse_cuda.non_diag_mask(
row, col, src.size(0), src.size(1), k)
else:
mask = torch.ops.torch_sparse_cpu.non_diag_mask(
row, col, src.size(0), src.size(1), k)
inv_mask = ~mask
start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
...
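The old `ext(is_cuda)` helper resolved the op namespace with `getattr(torch.ops, name)`, which TorchScript cannot compile, so each call site now branches explicitly between the CPU and CUDA namespaces, as in `set_diag` above. A minimal sketch of the dispatch pattern (the wrapper name is illustrative):

import torch

def non_diag_mask_dispatch(row: torch.Tensor, col: torch.Tensor,
                           M: int, N: int, k: int) -> torch.Tensor:
    # Explicit branching keeps the call scriptable.
    if row.is_cuda:
        return torch.ops.torch_sparse_cuda.non_diag_mask(row, col, M, N, k)
    return torch.ops.torch_sparse_cpu.non_diag_mask(row, col, M, N, k)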
import torch
import scipy.sparse
from torch_scatter import scatter_add
from torch_sparse.utils import ext
ext = None
class SPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
reduce):
out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
if mat.is_cuda:
out, arg_out = torch.ops.torch_sparse_cuda.spmm(
rowptr, col, value, mat, reduce)
else:
out, arg_out = torch.ops.torch_sparse_cpu.spmm(
rowptr, col, value, mat, reduce)
ctx.reduce = reduce
ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
...
import torch
from torch_sparse import transpose, to_scipy, from_scipy, coalesce
import torch_sparse.spspmm_cpu
# import torch_sparse.spspmm_cpu
if torch.cuda.is_available():
import torch_sparse.spspmm_cuda
# if torch.cuda.is_available():
# import torch_sparse.spspmm_cuda
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise NotImplementedError
if indexA.is_cuda and coalesced:
indexA, valueA = coalesce(indexA, valueA, m, k)
indexB, valueB = coalesce(indexB, valueB, k, n)
...
import warnings
from typing import Optional, List, Dict, Any
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import ext
from torch_sparse.utils import Final
__cache__ = {'enabled': True}
@@ -32,24 +33,24 @@ class no_cache(object):
return decorate_no_cache
class cached_property(object):
def __init__(self, func):
self.func = func
# class cached_property(object):
# def __init__(self, func):
# self.func = func
def __get__(self, obj, cls):
value = getattr(obj, f'_{self.func.__name__}', None)
if value is None:
value = self.func(obj)
if is_cache_enabled():
setattr(obj, f'_{self.func.__name__}', value)
return value
# def __get__(self, obj, cls):
# value = getattr(obj, f'_{self.func.__name__}', None)
# if value is None:
# value = self.func(obj)
# if is_cache_enabled():
# setattr(obj, f'_{self.func.__name__}', value)
# return value
def optional(func, src):
return func(src) if src is not None else src
layouts = ['coo', 'csr', 'csc']
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
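# `Final` marks the module-level list as a constant, which lets the
# TorchScript compiler inline it where scripted code refers to it.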
def get_layout(layout=None):
@@ -61,12 +62,30 @@ def get_layout(layout=None):
return layout
@torch.jit.script
class SparseStorage(object):
cache_keys = ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']
def __init__(self, row=None, rowptr=None, col=None, value=None,
sparse_size=None, rowcount=None, colptr=None, colcount=None,
csr2csc=None, csc2csr=None, is_sorted=False):
_row: Optional[torch.Tensor]
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_size: List[int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_size: Optional[List[int]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False):
assert row is not None or rowptr is not None
assert col is not None
@@ -75,9 +94,16 @@ class SparseStorage(object):
col = col.contiguous()
if sparse_size is None:
M = rowptr.numel() - 1 if row is None else row.max().item() + 1
if rowptr is not None:
M = rowptr.numel() - 1
elif row is not None:
M = row.max().item() + 1
else:
raise ValueError
N = col.max().item() + 1
sparse_size = torch.Size([M, N])
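# Tensor.item() is an untyped scalar to the TorchScript compiler, hence the
# explicit int(...) casts below before building torch.Size.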
sparse_size = torch.Size([int(M), int(N)])
else:
assert len(sparse_size) == 2
if row is not None:
assert row.dtype == torch.long
@@ -145,264 +171,303 @@ class SparseStorage(object):
self._csc2csr = csc2csr
if not is_sorted:
idx = self.col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row + self.col
idx = col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row() + col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
self._row = self.row[perm]
self._col = self.col[perm]
self._value = self.value[perm] if self.has_value() else None
self._row = self.row()[perm]
self._col = col[perm]
if value is not None:
self._value = value[perm]
self._csr2csc = None
self._csc2csr = None
def has_row(self):
def has_row(self) -> bool:
return self._row is not None
@property
def row(self):
if self._row is None:
self._row = ext(self.col.is_cuda).ptr2ind(self.rowptr,
self.col.numel())
return self._row
row = self._row
if row is not None:
return row
def has_rowptr(self):
rowptr = self._rowptr
if rowptr is not None:
if rowptr.is_cuda:
row = torch.ops.torch_sparse_cuda.ptr2ind(
rowptr, self._col.numel())
else:
row = torch.ops.torch_sparse_cpu.ptr2ind(
rowptr, self._col.numel())
self._row = row
return row
raise ValueError
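TorchScript narrows `Optional[Tensor]` types only through local variables, not attributes, which is why every getter above first copies `self._x` into a local before the `is not None` check. A minimal sketch of that caching-getter pattern (names are illustrative):

import torch
from typing import Optional

@torch.jit.script
class Cache(object):
    _val: Optional[torch.Tensor]

    def __init__(self):
        self._val = None

    def get(self) -> torch.Tensor:
        val = self._val  # copy to a local so TorchScript can refine the type
        if val is not None:
            return val
        val = torch.zeros(1)
        self._val = val
        return val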
def has_rowptr(self) -> bool:
return self._rowptr is not None
@property
def rowptr(self):
if self._rowptr is None:
self._rowptr = ext(self.col.is_cuda).ind2ptr(
self.row, self.sparse_size[0])
return self._rowptr
def rowptr(self) -> torch.Tensor:
rowptr = self._rowptr
if rowptr is not None:
return rowptr
@property
def col(self):
row = self._row
if row is not None:
if row.is_cuda:
rowptr = torch.ops.torch_sparse_cuda.ind2ptr(
row, self._sparse_size[0])
else:
rowptr = torch.ops.torch_sparse_cpu.ind2ptr(
row, self._sparse_size[0])
self._rowptr = rowptr
return rowptr
raise ValueError
def col(self) -> torch.Tensor:
return self._col
def has_value(self):
def has_value(self) -> bool:
return self._value is not None
@property
def value(self):
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.col.numel(), ), dtype=dtype,
device=self.col.device)
# def set_value_(self, value, layout=None, dtype=None):
# if isinstance(value, int) or isinstance(value, float):
# value = torch.full((self.col.numel(), ), dtype=dtype,
# device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
# elif torch.is_tensor(value) and get_layout(layout) == 'csc':
# value = value[self.csc2csr]
if torch.is_tensor(value):
value = value if dtype is None else value.to(dtype)
assert value.device == self.col.device
assert value.size(0) == self.col.numel()
# if torch.is_tensor(value):
# value = value if dtype is None else value.to(dtype)
# assert value.device == self.col.device
# assert value.size(0) == self.col.numel()
self._value = value
return self
# self._value = value
# return self
def set_value(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.col.numel(), ), dtype=dtype,
device=self.col.device)
# def set_value(self, value, layout=None, dtype=None):
# if isinstance(value, int) or isinstance(value, float):
# value = torch.full((self.col.numel(), ), dtype=dtype,
# device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
# elif torch.is_tensor(value) and get_layout(layout) == 'csc':
# value = value[self.csc2csr]
if torch.is_tensor(value):
value = value if dtype is None else value.to(dtype)
assert value.device == self.col.device
assert value.size(0) == self.col.numel()
# if torch.is_tensor(value):
# value = value if dtype is None else value.to(dtype)
# assert value.device == self.col.device
# assert value.size(0) == self.col.numel()
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
value=value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
# return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
# value=value, sparse_size=self._sparse_size,
# rowcount=self._rowcount, colptr=self._colptr,
# colcount=self._colcount, csr2csc=self._csr2csc,
# csc2csr=self._csc2csr, is_sorted=True)
@property
def sparse_size(self):
def sparse_size(self) -> List[int]:
return self._sparse_size
def sparse_resize(self, *sizes):
old_sparse_size, nnz = self.sparse_size, self.col.numel()
diff_0 = sizes[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if rowptr is not None:
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
else:
if rowptr is not None:
rowptr = rowptr[:-diff_0]
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sizes[1] - old_sparse_size[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if colptr is not None:
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
else:
if colptr is not None:
colptr = colptr[:-diff_1]
if colcount is not None:
colcount = colcount[:-diff_1]
return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
value=self.value, sparse_size=sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self):
# def sparse_resize(self, *sizes):
# old_sparse_size, nnz = self.sparse_size, self.col.numel()
# diff_0 = sizes[0] - old_sparse_size[0]
# rowcount, rowptr = self._rowcount, self._rowptr
# if diff_0 > 0:
# if rowptr is not None:
# rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
# if rowcount is not None:
# rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
# else:
# if rowptr is not None:
# rowptr = rowptr[:-diff_0]
# if rowcount is not None:
# rowcount = rowcount[:-diff_0]
# diff_1 = sizes[1] - old_sparse_size[1]
# colcount, colptr = self._colcount, self._colptr
# if diff_1 > 0:
# if colptr is not None:
# colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
# if colcount is not None:
# colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
# else:
# if colptr is not None:
# colptr = colptr[:-diff_1]
# if colcount is not None:
# colcount = colcount[:-diff_1]
# return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
# value=self.value, sparse_size=sizes,
# rowcount=rowcount, colptr=colptr,
# colcount=colcount, csr2csc=self._csr2csc,
# csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool:
return self._rowcount is not None
@cached_property
def rowcount(self):
return self.rowptr[1:] - self.rowptr[:-1]
def rowcount(self) -> torch.Tensor:
rowcount = self._rowcount
if rowcount is not None:
return rowcount
rowptr = self.rowptr()
rowcount = rowptr[1:] - rowptr[:-1]
self._rowcount = rowcount
return rowcount
def has_colptr(self):
def has_colptr(self) -> bool:
return self._colptr is not None
@cached_property
def colptr(self):
if self.has_csr2csc():
return ext(self.col.is_cuda).ind2ptr(self.col[self.csr2csc],
self.sparse_size[1])
else:
colptr = self.col.new_zeros(self.sparse_size[1] + 1)
torch.cumsum(self.colcount, dim=0, out=colptr[1:])
def colptr(self) -> torch.Tensor:
colptr = self._colptr
if colptr is not None:
return colptr
def has_colcount(self):
csr2csc = self._csr2csc
if csr2csc is not None:
if csr2csc.is_cuda:
colptr = torch.ops.torch_sparse_cuda.ind2ptr(
self._col[csr2csc], self._sparse_size[1])
else:
colptr = torch.ops.torch_sparse_cpu.ind2ptr(
self._col[csr2csc], self._sparse_size[1])
else:
colptr = self._col.new_zeros(self._sparse_size[1] + 1)
torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
self._colptr = colptr
return colptr
def has_colcount(self) -> bool:
return self._colcount is not None
@cached_property
def colcount(self):
if self.has_colptr():
return self.colptr[1:] - self.colptr[:-1]
def colcount(self) -> torch.Tensor:
colcount = self._colcount
if colcount is not None:
return colcount
colptr = self._colptr
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
return scatter_add(torch.ones_like(self.col), self.col,
dim_size=self.sparse_size[1])
raise NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
self._colcount = colcount
return colcount
def has_csr2csc(self):
def has_csr2csc(self) -> bool:
return self._csr2csc is not None
@cached_property
def csr2csc(self):
idx = self.sparse_size[0] * self.col + self.row
return idx.argsort()
def csr2csc(self) -> torch.Tensor:
csr2csc = self._csr2csc
if csr2csc is not None:
return csr2csc
def has_csc2csr(self):
idx = self._sparse_size[0] * self._col + self.row()
csr2csc = idx.argsort()
self._csr2csc = csr2csc
return csr2csc
def has_csc2csr(self) -> bool:
return self._csc2csr is not None
@cached_property
def csc2csr(self):
return self.csr2csc.argsort()
def csc2csr(self) -> torch.Tensor:
csc2csr = self._csc2csr
if csc2csr is not None:
return csc2csr
def is_coalesced(self):
idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col
return (idx[1:] > idx[:-1]).all().item()
csc2csr = self.csr2csc().argsort()
self._csc2csr = csc2csr
return csc2csr
def coalesce(self, reduce='add'):
idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col
def is_coalesced(self) -> bool:
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col
return bool((idx[1:] > idx[:-1]).all())
def coalesce(self, reduce: str = "add"):
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced.
return self
row = self.row[mask]
col = self.col[mask]
row = self.row()[mask]
col = self._col[mask]
value = self.value
if self.has_value():
value = self._value
if value is not None:
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
raise NotImplementedError
# value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return self.__class__(row=row, col=col, value=value,
sparse_size=self.sparse_size, is_sorted=True)
def cached_keys(self):
return [
key for key in self.cache_keys
if getattr(self, f'_{key}', None) is not None
]
def fill_cache_(self, *args):
for arg in args or self.cache_keys + ['row', 'rowptr']:
getattr(self, arg)
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=self._sparse_size, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
def fill_cache_(self):
self.row()
self.rowptr()
self.rowcount()
self.colptr()
self.colcount()
self.csr2csc()
self.csc2csr()
return self
def clear_cache_(self, *args):
for arg in args or self.cache_keys:
setattr(self, f'_{arg}', None)
def clear_cache_(self):
self._rowcount = None
self._colptr = None
self._colcount = None
self._csr2csc = None
self._csc2csr = None
return self
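# The former loops over `cache_keys` relied on dynamic getattr/setattr with
# f-string names, which TorchScript cannot compile; fill_cache_ and
# clear_cache_ therefore spell out each cached attribute explicitly.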
def __copy__(self):
return self.apply(lambda x: x)
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def clone(self):
return self.apply(lambda x: x.clone())
def __deepcopy__(self, memo):
new_storage = self.clone()
memo[id(self)] = new_storage
return new_storage
def apply_value_(self, func):
self._value = optional(func, self.value)
return self
def apply_value(self, func):
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
value=optional(func, self.value),
sparse_size=self.sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def apply_(self, func):
self._row = optional(func, self._row)
self._rowptr = optional(func, self._rowptr)
self._col = func(self.col)
self._value = optional(func, self.value)
for key in self.cached_keys():
setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
return self
def apply(self, func):
return self.__class__(
row=optional(func, self._row),
rowptr=optional(func, self._rowptr),
col=func(self.col),
value=optional(func, self.value),
sparse_size=self.sparse_size,
rowcount=optional(func, self._rowcount),
colptr=optional(func, self._colptr),
colcount=optional(func, self._colcount),
csr2csc=optional(func, self._csr2csc),
csc2csr=optional(func, self._csc2csr),
is_sorted=True,
)
def map(self, func):
data = []
if self.has_row():
data += [func(self.row)]
if self.has_rowptr():
data += [func(self.rowptr)]
data += [func(self.col)]
if self.has_value():
data += [func(self.value)]
data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
return data
row = self._row
if row is not None:
row = row.clone()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.clone()
value = self._value
if value is not None:
value = value.clone()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.clone()
colptr = self._colptr
if colptr is not None:
colptr = colptr.clone()
colcount = self._colcount
if colcount is not None:
colcount = colcount.clone()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.clone()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(),
value=value, sparse_size=self._sparse_size,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def __deepcopy__(self, memo: Dict[str, Any]):
return self.clone()
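Lambdas and the higher-order `apply`/`optional` helpers are not scriptable, so `clone` above copies each cached tensor explicitly. A small scriptable helper along these lines could shorten the unrolled body (the name `optional_clone` is illustrative):

import torch
from typing import Optional

def optional_clone(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Scriptable stand-in for optional(lambda t: t.clone(), x).
    if x is not None:
        return x.clone()
    return None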
from typing import Any
import torch
try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so')
@@ -14,10 +21,5 @@ except OSError as e:
raise e
def ext(is_cuda):
name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
return getattr(torch.ops, name)
def is_scalar(other):
def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)