"tests/vscode:/vscode.git/clone" did not exist on "86c243b45f0d1652a476d9c5ac165f22bf95c91e"
Commit f59fe649 authored by rusty1s's avatar rusty1s
Browse files

beginning of TorchScript support

parent c4484dbb
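
The pattern this commit adopts throughout: torch.jit.script cannot compile @property accessors, the custom cached_property descriptor, or getattr-based op dispatch, so cached values move into explicitly annotated Optional fields that plain methods check, fill, and return. A minimal self-contained sketch of that caching pattern (illustrative only; CSRMeta is a made-up name, not part of this commit):

import torch
from typing import Optional


@torch.jit.script
class CSRMeta(object):
    # TorchScript classes need explicit attribute annotations; the cache
    # slot is an Optional tensor instead of a @property-backed attribute.
    rowptr: torch.Tensor
    _rowcount: Optional[torch.Tensor]

    def __init__(self, rowptr: torch.Tensor):
        self.rowptr = rowptr
        self._rowcount = None

    def rowcount(self) -> torch.Tensor:
        rowcount = self._rowcount
        if rowcount is not None:  # cache hit
            return rowcount
        # Cache miss: entries per row are differences of consecutive
        # CSR pointers; compute once and memoize.
        rowcount = self.rowptr[1:] - self.rowptr[:-1]
        self._rowcount = rowcount
        return rowcount

The local `rowcount = self._rowcount` assignment matters: TorchScript refines the Optional to a plain Tensor through the local variable, which is why the storage.py rewrite below repeats this dance in every accessor.
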
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.storage import SparseStorage
from typing import Dict, Any


# class MyTensor(dict):
#     def __init__(self, rowptr, col):
#         self['rowptr'] = rowptr
#         self['col'] = col
#
#     def rowptr(self: Dict[str, torch.Tensor]):
#         return self['rowptr']


@torch.jit.script
class Foo:
    rowptr: torch.Tensor
    col: torch.Tensor

    def __init__(self, rowptr: torch.Tensor, col: torch.Tensor):
        self.rowptr = rowptr
        self.col = col


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(2, 4)

    # def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor:
        out, _ = torch.ops.torch_sparse_cpu.spmm(adj.rowptr(), adj.col(), None,
                                                 x, 'sum')
        return out
        # ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
        # # ind = ptr2ind(ptr, E)
        # x_j = x[ind]
        # out = self.linear(x_j)
        # return out


def test_jit():
    my_cell = MyCell()

    # x = torch.rand(3, 2)
    # ptr = torch.tensor([0, 2, 4, 6])
    # out = my_cell(x, ptr)
    # print()
    # print(out)
    # traced_cell = torch.jit.trace(my_cell, (x, ptr))
    # print(traced_cell)
    # out = traced_cell(x, ptr)
    # print(out)

    x = torch.randn(3, 2)
    # adj = torch.randn(3, 3)
    # adj = SparseTensor.from_dense(adj)
    # adj = Foo(adj.storage.rowptr, adj.storage.col)
    # adj = adj.storage
    rowptr = torch.tensor([0, 3, 6, 9])
    col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    adj = SparseStorage(rowptr=rowptr, col=col)
    # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
    # foo = Foo(mat.storage.rowptr, mat.storage.col)
    # adj = MyTensor(mat.storage.rowptr, mat.storage.col)

    traced_cell = torch.jit.script(my_cell)
    print(traced_cell)
    out = traced_cell(x, adj)
    print(out)
    # # print(traced_cell.code)
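
For orientation, the rowptr/col pair in this test encodes the all-ones 3x3 adjacency in CSR form, so the scripted spmm with reduce='sum' simply sums each row's neighbor features. A pure-PyTorch reference for the same computation (not part of the commit, and independent of the compiled extension):

import torch

rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
x = torch.randn(3, 2)

# Expand CSR pointers to per-edge row indices, gather neighbor features,
# and sum them per row: equivalent to A @ x with an all-ones A here.
row = torch.repeat_interleave(
    torch.arange(rowptr.numel() - 1), rowptr[1:] - rowptr[:-1])
out = torch.zeros(rowptr.numel() - 1, x.size(1)).index_add_(0, row, x[col])
print(out)  # shape (3, 2); all three rows equal x.sum(0)
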
 import torch
 import torch_scatter
-from .unique import unique
+# from .unique import unique


 def coalesce(index, value, m, n, op='add', fill_value=0):
...
@@ -22,6 +22,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
+    raise NotImplementedError
     row, col = index
...
 import torch
-from torch_sparse.utils import ext


 def remove_diag(src, k=0):
     row, col, value = src.coo()
...
@@ -39,8 +37,13 @@ def set_diag(src, values=None, k=0):
     row, col, value = src.coo()
-    mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1),
-                                          k)
+    if row.is_cuda:
+        mask = torch.ops.torch_sparse_cuda.non_diag_mask(
+            row, col, src.size(0), src.size(1), k)
+    else:
+        mask = torch.ops.torch_sparse_cpu.non_diag_mask(
+            row, col, src.size(0), src.size(1), k)
     inv_mask = ~mask
     start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
...
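
The same mechanical rewrite recurs in every file of this commit: the getattr-based ext(is_cuda) helper (removed from utils.py at the bottom of this diff) is replaced by an explicit if/else over the two op namespaces, because TorchScript must resolve op calls statically. A minimal sketch of the rewritten call shape (my_op is a placeholder, not a real torch_sparse op, and it assumes the extension libraries are already loaded):

import torch

def dispatch(x: torch.Tensor) -> torch.Tensor:
    # Spelled-out device dispatch: both branches name a concrete op that
    # torch.jit.script can resolve, unlike getattr(torch.ops, name).
    if x.is_cuda:
        return torch.ops.torch_sparse_cuda.my_op(x)
    else:
        return torch.ops.torch_sparse_cpu.my_op(x)
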
 import torch
 import scipy.sparse
 from torch_scatter import scatter_add
-from torch_sparse.utils import ext
+ext = None


 class SPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                 reduce):
-        out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
+        if mat.is_cuda:
+            out, arg_out = torch.ops.torch_sparse_cuda.spmm(
+                rowptr, col, value, mat, reduce)
+        else:
+            out, arg_out = torch.ops.torch_sparse_cpu.spmm(
+                rowptr, col, value, mat, reduce)
         ctx.reduce = reduce
         ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
...
 import torch
 from torch_sparse import transpose, to_scipy, from_scipy, coalesce
-import torch_sparse.spspmm_cpu
-if torch.cuda.is_available():
-    import torch_sparse.spspmm_cuda
+# import torch_sparse.spspmm_cpu
+# if torch.cuda.is_available():
+#     import torch_sparse.spspmm_cuda


 def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
...
@@ -25,6 +25,7 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
+    raise NotImplementedError
     if indexA.is_cuda and coalesced:
         indexA, valueA = coalesce(indexA, valueA, m, k)
         indexB, valueB = coalesce(indexB, valueB, k, n)
...
 import warnings
+from typing import Optional, List, Dict, Any

 import torch
 from torch_scatter import segment_csr, scatter_add
-from torch_sparse.utils import ext
+from torch_sparse.utils import Final

 __cache__ = {'enabled': True}
...
@@ -32,24 +33,24 @@ class no_cache(object):
         return decorate_no_cache


-class cached_property(object):
-    def __init__(self, func):
-        self.func = func
-
-    def __get__(self, obj, cls):
-        value = getattr(obj, f'_{self.func.__name__}', None)
-        if value is None:
-            value = self.func(obj)
-            if is_cache_enabled():
-                setattr(obj, f'_{self.func.__name__}', value)
-        return value
+# class cached_property(object):
+#     def __init__(self, func):
+#         self.func = func
+#
+#     def __get__(self, obj, cls):
+#         value = getattr(obj, f'_{self.func.__name__}', None)
+#         if value is None:
+#             value = self.func(obj)
+#             if is_cache_enabled():
+#                 setattr(obj, f'_{self.func.__name__}', value)
+#         return value


 def optional(func, src):
     return func(src) if src is not None else src


-layouts = ['coo', 'csr', 'csc']
+layouts: Final[List[str]] = ['coo', 'csr', 'csc']


 def get_layout(layout=None):
...
@@ -61,12 +62,30 @@ def get_layout(layout=None):
     return layout
+@torch.jit.script
 class SparseStorage(object):
-    cache_keys = ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']
-
-    def __init__(self, row=None, rowptr=None, col=None, value=None,
-                 sparse_size=None, rowcount=None, colptr=None, colcount=None,
-                 csr2csc=None, csc2csr=None, is_sorted=False):
+    _row: Optional[torch.Tensor]
+    _rowptr: Optional[torch.Tensor]
+    _col: torch.Tensor
+    _value: Optional[torch.Tensor]
+    _sparse_size: List[int]
+    _rowcount: Optional[torch.Tensor]
+    _colptr: Optional[torch.Tensor]
+    _colcount: Optional[torch.Tensor]
+    _csr2csc: Optional[torch.Tensor]
+    _csc2csr: Optional[torch.Tensor]
+
+    def __init__(self, row: Optional[torch.Tensor] = None,
+                 rowptr: Optional[torch.Tensor] = None,
+                 col: Optional[torch.Tensor] = None,
+                 value: Optional[torch.Tensor] = None,
+                 sparse_size: Optional[List[int]] = None,
+                 rowcount: Optional[torch.Tensor] = None,
+                 colptr: Optional[torch.Tensor] = None,
+                 colcount: Optional[torch.Tensor] = None,
+                 csr2csc: Optional[torch.Tensor] = None,
+                 csc2csr: Optional[torch.Tensor] = None,
+                 is_sorted: bool = False):
         assert row is not None or rowptr is not None
         assert col is not None
...
@@ -75,9 +94,16 @@ class SparseStorage(object):
         col = col.contiguous()

         if sparse_size is None:
-            M = rowptr.numel() - 1 if row is None else row.max().item() + 1
+            if rowptr is not None:
+                M = rowptr.numel() - 1
+            elif row is not None:
+                M = row.max().item() + 1
+            else:
+                raise ValueError
             N = col.max().item() + 1
-            sparse_size = torch.Size([M, N])
+            sparse_size = torch.Size([int(M), int(N)])
+        else:
+            assert len(sparse_size) == 2

         if row is not None:
             assert row.dtype == torch.long
...
@@ -145,264 +171,303 @@ class SparseStorage(object):
         self._csc2csr = csc2csr

         if not is_sorted:
-            idx = self.col.new_zeros(col.numel() + 1)
-            idx[1:] = sparse_size[1] * self.row + self.col
+            idx = col.new_zeros(col.numel() + 1)
+            idx[1:] = sparse_size[1] * self.row() + col
             if (idx[1:] < idx[:-1]).any():
                 perm = idx[1:].argsort()
-                self._row = self.row[perm]
-                self._col = self.col[perm]
-                self._value = self.value[perm] if self.has_value() else None
+                self._row = self.row()[perm]
+                self._col = col[perm]
+                if value is not None:
+                    self._value = value[perm]
                 self._csr2csc = None
                 self._csc2csr = None

-    def has_row(self):
+    def has_row(self) -> bool:
         return self._row is not None

-    @property
     def row(self):
-        if self._row is None:
-            self._row = ext(self.col.is_cuda).ptr2ind(self.rowptr,
-                                                      self.col.numel())
-        return self._row
+        row = self._row
+        if row is not None:
+            return row
+
+        rowptr = self._rowptr
+        if rowptr is not None:
+            if rowptr.is_cuda:
+                row = torch.ops.torch_sparse_cuda.ptr2ind(
+                    rowptr, self._col.numel())
+            else:
+                row = torch.ops.torch_sparse_cpu.ptr2ind(
+                    rowptr, self._col.numel())
+            self._row = row
+            return row
+        raise ValueError

-    def has_rowptr(self):
+    def has_rowptr(self) -> bool:
         return self._rowptr is not None

-    @property
-    def rowptr(self):
-        if self._rowptr is None:
-            self._rowptr = ext(self.col.is_cuda).ind2ptr(
-                self.row, self.sparse_size[0])
-        return self._rowptr
+    def rowptr(self) -> torch.Tensor:
+        rowptr = self._rowptr
+        if rowptr is not None:
+            return rowptr
+
+        row = self._row
+        if row is not None:
+            if row.is_cuda:
+                rowptr = torch.ops.torch_sparse_cuda.ind2ptr(
+                    row, self._sparse_size[0])
+            else:
+                rowptr = torch.ops.torch_sparse_cpu.ind2ptr(
+                    row, self._sparse_size[0])
+            self._rowptr = rowptr
+            return rowptr
+        raise ValueError

-    @property
-    def col(self):
+    def col(self) -> torch.Tensor:
         return self._col

-    def has_value(self):
+    def has_value(self) -> bool:
         return self._value is not None

-    @property
-    def value(self):
+    def value(self) -> Optional[torch.Tensor]:
         return self._value
-    def set_value_(self, value, layout=None, dtype=None):
-        if isinstance(value, int) or isinstance(value, float):
-            value = torch.full((self.col.numel(), ), dtype=dtype,
-                               device=self.col.device)
-        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
-            value = value[self.csc2csr]
-
-        if torch.is_tensor(value):
-            value = value if dtype is None else value.to(dtype)
-            assert value.device == self.col.device
-            assert value.size(0) == self.col.numel()
-
-        self._value = value
-        return self
+    # def set_value_(self, value, layout=None, dtype=None):
+    #     if isinstance(value, int) or isinstance(value, float):
+    #         value = torch.full((self.col.numel(), ), dtype=dtype,
+    #                            device=self.col.device)
+    #     elif torch.is_tensor(value) and get_layout(layout) == 'csc':
+    #         value = value[self.csc2csr]
+    #
+    #     if torch.is_tensor(value):
+    #         value = value if dtype is None else value.to(dtype)
+    #         assert value.device == self.col.device
+    #         assert value.size(0) == self.col.numel()
+    #
+    #     self._value = value
+    #     return self

-    def set_value(self, value, layout=None, dtype=None):
-        if isinstance(value, int) or isinstance(value, float):
-            value = torch.full((self.col.numel(), ), dtype=dtype,
-                               device=self.col.device)
-        elif torch.is_tensor(value) and get_layout(layout) == 'csc':
-            value = value[self.csc2csr]
-
-        if torch.is_tensor(value):
-            value = value if dtype is None else value.to(dtype)
-            assert value.device == self.col.device
-            assert value.size(0) == self.col.numel()
-
-        return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
-                              value=value, sparse_size=self._sparse_size,
-                              rowcount=self._rowcount, colptr=self._colptr,
-                              colcount=self._colcount, csr2csc=self._csr2csc,
-                              csc2csr=self._csc2csr, is_sorted=True)
+    # def set_value(self, value, layout=None, dtype=None):
+    #     if isinstance(value, int) or isinstance(value, float):
+    #         value = torch.full((self.col.numel(), ), dtype=dtype,
+    #                            device=self.col.device)
+    #     elif torch.is_tensor(value) and get_layout(layout) == 'csc':
+    #         value = value[self.csc2csr]
+    #
+    #     if torch.is_tensor(value):
+    #         value = value if dtype is None else value.to(dtype)
+    #         assert value.device == self.col.device
+    #         assert value.size(0) == self.col.numel()
+    #
+    #     return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
+    #                           value=value, sparse_size=self._sparse_size,
+    #                           rowcount=self._rowcount, colptr=self._colptr,
+    #                           colcount=self._colcount, csr2csc=self._csr2csc,
+    #                           csc2csr=self._csc2csr, is_sorted=True)

-    @property
-    def sparse_size(self):
+    def sparse_size(self) -> List[int]:
         return self._sparse_size

-    def sparse_resize(self, *sizes):
-        old_sparse_size, nnz = self.sparse_size, self.col.numel()
-
-        diff_0 = sizes[0] - old_sparse_size[0]
-        rowcount, rowptr = self._rowcount, self._rowptr
-        if diff_0 > 0:
-            if rowptr is not None:
-                rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
-            if rowcount is not None:
-                rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
-        else:
-            if rowptr is not None:
-                rowptr = rowptr[:-diff_0]
-            if rowcount is not None:
-                rowcount = rowcount[:-diff_0]
-
-        diff_1 = sizes[1] - old_sparse_size[1]
-        colcount, colptr = self._colcount, self._colptr
-        if diff_1 > 0:
-            if colptr is not None:
-                colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
-            if colcount is not None:
-                colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
-        else:
-            if colptr is not None:
-                colptr = colptr[:-diff_1]
-            if colcount is not None:
-                colcount = colcount[:-diff_1]
-
-        return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
-                              value=self.value, sparse_size=sizes,
-                              rowcount=rowcount, colptr=colptr,
-                              colcount=colcount, csr2csc=self._csr2csc,
-                              csc2csr=self._csc2csr, is_sorted=True)
+    # def sparse_resize(self, *sizes):
+    #     old_sparse_size, nnz = self.sparse_size, self.col.numel()
+    #
+    #     diff_0 = sizes[0] - old_sparse_size[0]
+    #     rowcount, rowptr = self._rowcount, self._rowptr
+    #     if diff_0 > 0:
+    #         if rowptr is not None:
+    #             rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
+    #         if rowcount is not None:
+    #             rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
+    #     else:
+    #         if rowptr is not None:
+    #             rowptr = rowptr[:-diff_0]
+    #         if rowcount is not None:
+    #             rowcount = rowcount[:-diff_0]
+    #
+    #     diff_1 = sizes[1] - old_sparse_size[1]
+    #     colcount, colptr = self._colcount, self._colptr
+    #     if diff_1 > 0:
+    #         if colptr is not None:
+    #             colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
+    #         if colcount is not None:
+    #             colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
+    #     else:
+    #         if colptr is not None:
+    #             colptr = colptr[:-diff_1]
+    #         if colcount is not None:
+    #             colcount = colcount[:-diff_1]
+    #
+    #     return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
+    #                           value=self.value, sparse_size=sizes,
+    #                           rowcount=rowcount, colptr=colptr,
+    #                           colcount=colcount, csr2csc=self._csr2csc,
+    #                           csc2csr=self._csc2csr, is_sorted=True)
-    def has_rowcount(self):
+    def has_rowcount(self) -> bool:
         return self._rowcount is not None

-    @cached_property
-    def rowcount(self):
-        return self.rowptr[1:] - self.rowptr[:-1]
+    def rowcount(self) -> torch.Tensor:
+        rowcount = self._rowcount
+        if rowcount is not None:
+            return rowcount
+
+        rowptr = self.rowptr()
+        rowcount = rowptr[1:] - rowptr[:-1]
+        self._rowcount = rowcount
+        return rowcount

-    def has_colptr(self):
+    def has_colptr(self) -> bool:
         return self._colptr is not None

-    @cached_property
-    def colptr(self):
-        if self.has_csr2csc():
-            return ext(self.col.is_cuda).ind2ptr(self.col[self.csr2csc],
-                                                 self.sparse_size[1])
-        else:
-            colptr = self.col.new_zeros(self.sparse_size[1] + 1)
-            torch.cumsum(self.colcount, dim=0, out=colptr[1:])
-            return colptr
+    def colptr(self) -> torch.Tensor:
+        colptr = self._colptr
+        if colptr is not None:
+            return colptr
+
+        csr2csc = self._csr2csc
+        if csr2csc is not None:
+            colptr = torch.ops.torch_sparse_cpu.ind2ptr(
+                self._col[csr2csc], self._sparse_size[1])
+        else:
+            colptr = self._col.new_zeros(self._sparse_size[1] + 1)
+            torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
+        self._colptr = colptr
+        return colptr

-    def has_colcount(self):
+    def has_colcount(self) -> bool:
         return self._colcount is not None

-    @cached_property
-    def colcount(self):
-        if self.has_colptr():
-            return self.colptr[1:] - self.colptr[:-1]
-        else:
-            return scatter_add(torch.ones_like(self.col), self.col,
-                               dim_size=self.sparse_size[1])
+    def colcount(self) -> torch.Tensor:
+        colcount = self._colcount
+        if colcount is not None:
+            return colcount
+
+        colptr = self._colptr
+        if colptr is not None:
+            colcount = colptr[1:] - colptr[:-1]
+        else:
+            raise NotImplementedError
+            # colcount = scatter_add(torch.ones_like(self._col), self._col,
+            #                        dim_size=self._sparse_size[1])
+        self._colcount = colcount
+        return colcount

-    def has_csr2csc(self):
+    def has_csr2csc(self) -> bool:
         return self._csr2csc is not None

-    @cached_property
-    def csr2csc(self):
-        idx = self.sparse_size[0] * self.col + self.row
-        return idx.argsort()
+    def csr2csc(self) -> torch.Tensor:
+        csr2csc = self._csr2csc
+        if csr2csc is not None:
+            return csr2csc
+
+        idx = self._sparse_size[0] * self._col + self.row()
+        csr2csc = idx.argsort()
+        self._csr2csc = csr2csc
+        return csr2csc

-    def has_csc2csr(self):
+    def has_csc2csr(self) -> bool:
         return self._csc2csr is not None

-    @cached_property
-    def csc2csr(self):
-        return self.csr2csc.argsort()
+    def csc2csr(self) -> torch.Tensor:
+        csc2csr = self._csc2csr
+        if csc2csr is not None:
+            return csc2csr
+
+        csc2csr = self.csr2csc().argsort()
+        self._csc2csr = csc2csr
+        return csc2csr

-    def is_coalesced(self):
-        idx = self.col.new_full((self.col.numel() + 1, ), -1)
-        idx[1:] = self.sparse_size[1] * self.row + self.col
-        return (idx[1:] > idx[:-1]).all().item()
+    def is_coalesced(self) -> bool:
+        idx = self._col.new_full((self._col.numel() + 1, ), -1)
+        idx[1:] = self._sparse_size[1] * self.row() + self._col
+        return bool((idx[1:] > idx[:-1]).all())

-    def coalesce(self, reduce='add'):
-        idx = self.col.new_full((self.col.numel() + 1, ), -1)
-        idx[1:] = self.sparse_size[1] * self.row + self.col
+    def coalesce(self, reduce: str = "add"):
+        idx = self._col.new_full((self._col.numel() + 1, ), -1)
+        idx[1:] = self._sparse_size[1] * self.row() + self._col
         mask = idx[1:] > idx[:-1]
         if mask.all():  # Skip if indices are already coalesced.
             return self

-        row = self.row[mask]
-        col = self.col[mask]
+        row = self.row()[mask]
+        col = self._col[mask]

-        value = self.value
-        if self.has_value():
+        value = self._value
+        if value is not None:
             ptr = mask.nonzero().flatten()
             ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
-            value = segment_csr(value, ptr, reduce=reduce)
+            raise NotImplementedError
+            # value = segment_csr(value, ptr, reduce=reduce)
             value = value[0] if isinstance(value, tuple) else value

-        return self.__class__(row=row, col=col, value=value,
-                              sparse_size=self.sparse_size, is_sorted=True)
+        return SparseStorage(row=row, rowptr=None, col=col, value=value,
+                             sparse_size=self._sparse_size, rowcount=None,
+                             colptr=None, colcount=None, csr2csc=None,
+                             csc2csr=None, is_sorted=True)

-    def cached_keys(self):
-        return [
-            key for key in self.cache_keys
-            if getattr(self, f'_{key}', None) is not None
-        ]
-
-    def fill_cache_(self, *args):
-        for arg in args or self.cache_keys + ['row', 'rowptr']:
-            getattr(self, arg)
+    def fill_cache_(self):
+        self.row()
+        self.rowptr()
+        self.rowcount()
+        self.colptr()
+        self.colcount()
+        self.csr2csc()
+        self.csc2csr()
         return self

-    def clear_cache_(self, *args):
-        for arg in args or self.cache_keys:
-            setattr(self, f'_{arg}', None)
+    def clear_cache_(self):
+        self._rowcount = None
+        self._colptr = None
+        self._colcount = None
+        self._csr2csc = None
+        self._csc2csr = None
         return self

     def __copy__(self):
-        return self.apply(lambda x: x)
+        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
+                             value=self._value, sparse_size=self._sparse_size,
+                             rowcount=self._rowcount, colptr=self._colptr,
+                             colcount=self._colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
     def clone(self):
-        return self.apply(lambda x: x.clone())
+        row = self._row
+        if row is not None:
+            row = row.clone()
+        rowptr = self._rowptr
+        if rowptr is not None:
+            rowptr = rowptr.clone()
+        value = self._value
+        if value is not None:
+            value = value.clone()
+        rowcount = self._rowcount
+        if rowcount is not None:
+            rowcount = rowcount.clone()
+        colptr = self._colptr
+        if colptr is not None:
+            colptr = colptr.clone()
+        colcount = self._colcount
+        if colcount is not None:
+            colcount = colcount.clone()
+        csr2csc = self._csr2csc
+        if csr2csc is not None:
+            csr2csc = csr2csc.clone()
+        csc2csr = self._csc2csr
+        if csc2csr is not None:
+            csc2csr = csc2csr.clone()
+        return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(),
+                             value=value, sparse_size=self._sparse_size,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=csr2csc,
+                             csc2csr=csc2csr, is_sorted=True)

-    def __deepcopy__(self, memo):
-        new_storage = self.clone()
-        memo[id(self)] = new_storage
-        return new_storage
+    def __deepcopy__(self, memo: Dict[str, Any]):
+        return self.clone()

-    def apply_value_(self, func):
-        self._value = optional(func, self.value)
-        return self
-
-    def apply_value(self, func):
-        return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
-                              value=optional(func, self.value),
-                              sparse_size=self.sparse_size,
-                              rowcount=self._rowcount, colptr=self._colptr,
-                              colcount=self._colcount, csr2csc=self._csr2csc,
-                              csc2csr=self._csc2csr, is_sorted=True)
-
-    def apply_(self, func):
-        self._row = optional(func, self._row)
-        self._rowptr = optional(func, self._rowptr)
-        self._col = func(self.col)
-        self._value = optional(func, self.value)
-        for key in self.cached_keys():
-            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
-        return self
-
-    def apply(self, func):
-        return self.__class__(
-            row=optional(func, self._row),
-            rowptr=optional(func, self._rowptr),
-            col=func(self.col),
-            value=optional(func, self.value),
-            sparse_size=self.sparse_size,
-            rowcount=optional(func, self._rowcount),
-            colptr=optional(func, self._colptr),
-            colcount=optional(func, self._colcount),
-            csr2csc=optional(func, self._csr2csc),
-            csc2csr=optional(func, self._csc2csr),
-            is_sorted=True,
-        )
-
-    def map(self, func):
-        data = []
-        if self.has_row():
-            data += [func(self.row)]
-        if self.has_rowptr():
-            data += [func(self.rowptr)]
-        data += [func(self.col)]
-        if self.has_value():
-            data += [func(self.value)]
-        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
-        return data
...
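
After this rewrite, every derived tensor of SparseStorage is produced by a plain method that checks an Optional cache field, computes on a miss, and memoizes the result. A short usage sketch (illustrative, not part of the commit; it assumes the compiled torch_sparse extension libraries load successfully so that the ptr2ind op is available):

import torch
from torch_sparse.storage import SparseStorage

rowptr = torch.tensor([0, 2, 3])
col = torch.tensor([0, 1, 1])
# is_sorted=True skips the sorting check in __init__, which would itself
# call row() and warm the cache.
adj = SparseStorage(rowptr=rowptr, col=col, is_sorted=True)

assert not adj.has_row()  # nothing cached yet
row = adj.row()           # computed from rowptr via the ptr2ind op
assert adj.has_row()      # memoized; later calls return the cached tensor
adj.clear_cache_()        # resets rowcount/colptr/colcount/csr2csc/csc2csr
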
+from typing import Any
+
 import torch

+try:
+    from typing_extensions import Final  # noqa
+except ImportError:
+    from torch.jit import Final  # noqa
+
 torch.ops.load_library('torch_sparse/convert_cpu.so')
 torch.ops.load_library('torch_sparse/diag_cpu.so')
 torch.ops.load_library('torch_sparse/spmm_cpu.so')
...
@@ -14,10 +21,5 @@ except OSError as e:
     raise e

-def ext(is_cuda):
-    name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
-    return getattr(torch.ops, name)
-
-
-def is_scalar(other):
+def is_scalar(other: Any) -> bool:
     return isinstance(other, int) or isinstance(other, float)