Commit 15ff09d5 authored by rusty1s
Browse files

removed load library calls

parent 60a29466
# flake8: noqa
import importlib import importlib
import os.path as osp import os.path as osp
...@@ -9,8 +7,9 @@ __version__ = '0.5.1' ...@@ -9,8 +7,9 @@ __version__ = '0.5.1'
expected_torch_version = (1, 4) expected_torch_version = (1, 4)
try: try:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']:
'_version', [osp.dirname(__file__)]).origin) torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except OSError as e: except OSError as e:
if 'undefined symbol' in str(e): if 'undefined symbol' in str(e):
major, minor = [int(x) for x in torch.__version__.split('.')[:2]] major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
...@@ -40,26 +39,27 @@ if torch.version.cuda is not None: # pragma: no cover ...@@ -40,26 +39,27 @@ if torch.version.cuda is not None: # pragma: no cover
f'{major}.{minor}. Please reinstall the torch_sparse that ' f'{major}.{minor}. Please reinstall the torch_sparse that '
f'matches your PyTorch install.') f'matches your PyTorch install.')
from .storage import SparseStorage from .storage import SparseStorage # noqa: E4402
from .tensor import SparseTensor from .tensor import SparseTensor # noqa: E4402
from .transpose import t from .transpose import t # noqa: E4402
from .narrow import narrow, __narrow_diag__ from .narrow import narrow, __narrow_diag__ # noqa: E4402
from .select import select from .select import select # noqa: E4402
from .index_select import index_select, index_select_nnz from .index_select import index_select, index_select_nnz # noqa: E4402
from .masked_select import masked_select, masked_select_nnz from .masked_select import masked_select, masked_select_nnz # noqa: E4402
from .diag import remove_diag, set_diag, fill_diag from .diag import remove_diag, set_diag, fill_diag # noqa: E4402
from .add import add, add_, add_nnz, add_nnz_ from .add import add, add_, add_nnz, add_nnz_ # noqa: E4402
from .mul import mul, mul_, mul_nnz, mul_nnz_ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa: E4402
from .reduce import sum, mean, min, max from .reduce import sum, mean, min, max # noqa: E4402
from .matmul import matmul from .matmul import matmul # noqa: E4402
from .cat import cat, cat_diag from .cat import cat, cat_diag # noqa: E4402
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy from .convert import to_torch_sparse, from_torch_sparse # noqa: E4402
from .coalesce import coalesce from .convert import to_scipy, from_scipy # noqa: E4402
from .transpose import transpose from .coalesce import coalesce # noqa: E4402
from .eye import eye from .transpose import transpose # noqa: E4402
from .spmm import spmm from .eye import eye # noqa: E4402
from .spspmm import spspmm from .spmm import spmm # noqa: E4402
from .spspmm import spspmm # noqa: E4402
__all__ = [ __all__ = [
'SparseStorage', 'SparseStorage',
......
from typing import List, Optional from typing import List
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
...@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(rowcounts) == len(tensors): if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0) rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(
sparse_sizes=sparse_sizes, rowcount=rowcount, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=rowptr,
csc2csr=None, is_sorted=True) col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim == 1: elif dim == 1:
...@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(colcounts) == len(tensors): if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0) colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, storage = SparseStorage(
sparse_sizes=sparse_sizes, rowcount=None, row=row,
colptr=colptr, colcount=colcount, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=False) col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=False)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim(): elif dim > 1 and dim < tensors[0].dim():
...@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
if len(csc2csrs) == len(tensors): if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0) csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(
sparse_sizes=sparse_sizes, rowcount=rowcount, row=row,
colptr=colptr, colcount=colcount, csr2csc=csr2csc, rowptr=rowptr,
csc2csr=csc2csr, is_sorted=True) col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
import importlib
import os.path as osp
from typing import Optional from typing import Optional
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
'_diag', [osp.dirname(__file__)]).origin)
@torch.jit.script @torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: ...@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone() colcount = colcount.clone()
colcount[col[mask]] -= 1 colcount[col[mask]] -= 1
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value, storage = SparseStorage(
sparse_sizes=src.sparse_sizes(), rowcount=rowcount, row=new_row,
colptr=None, colcount=colcount, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True) col=new_col,
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
@torch.jit.script @torch.jit.script
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None, def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor: k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k) src = remove_diag(src, k=k)
row, col, value = src.coo() row, col, value = src.coo()
...@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None, ...@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
if values is not None: if values is not None:
new_value[inv_mask] = values new_value[inv_mask] = values
else: else:
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype, new_value[inv_mask] = torch.ones((num_diag, ),
dtype=value.dtype,
device=value.device) device=value.device)
rowcount = src.storage._rowcount rowcount = src.storage._rowcount
...@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None, ...@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
colcount = colcount.clone() colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1 colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, storage = SparseStorage(
value=new_value, sparse_sizes=src.sparse_sizes(), row=new_row,
rowcount=rowcount, colptr=None, colcount=colcount, rowptr=None,
csr2csc=None, csc2csr=None, is_sorted=True) col=new_col,
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
......
import importlib
import os.path as osp
from typing import Union, Tuple from typing import Union, Tuple
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
'_spmm', [osp.dirname(__file__)]).origin)
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
'_spspmm', [osp.dirname(__file__)]).origin)
@torch.jit.script @torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
...@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: ...@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1) M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC, return SparseTensor(
sparse_sizes=torch.Size([M, K]), is_sorted=True) row=None,
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
@torch.jit.script @torch.jit.script
...@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor, ...@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError raise ValueError
def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor], def matmul(src: SparseTensor,
other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"): reduce: str = "sum"):
if torch.is_tensor(other): if torch.is_tensor(other):
return spmm(src, other, reduce) return spmm(src, other, reduce)
......
import warnings import warnings
import importlib
import os.path as osp
from typing import Optional, List from typing import Optional, List
import torch import torch
from torch_scatter import segment_csr, scatter_add from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final from torch_sparse.utils import Final
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
'_convert', [osp.dirname(__file__)]).origin)
layouts: Final[List[str]] = ['coo', 'csr', 'csc'] layouts: Final[List[str]] = ['coo', 'csr', 'csc']
...@@ -35,7 +30,8 @@ class SparseStorage(object): ...@@ -35,7 +30,8 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor] _csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor] _csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None, def __init__(self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
...@@ -196,7 +192,8 @@ class SparseStorage(object): ...@@ -196,7 +192,8 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]:
    """Return the cached value tensor of the sparse matrix, or ``None``
    when the matrix stores no explicit values (implicit all-ones)."""
    return self._value
def set_value_(self, value: Optional[torch.Tensor], def set_value_(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
...@@ -208,7 +205,8 @@ class SparseStorage(object): ...@@ -208,7 +205,8 @@ class SparseStorage(object):
self._value = value self._value = value
return self return self
def set_value(self, value: Optional[torch.Tensor], def set_value(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
...@@ -217,11 +215,18 @@ class SparseStorage(object): ...@@ -217,11 +215,18 @@ class SparseStorage(object):
assert value.device == self._col.device assert value.device == self._col.device
assert value.size(0) == self._col.numel() assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(
value=value, sparse_sizes=self._sparse_sizes, row=self._row,
rowcount=self._rowcount, colptr=self._colptr, rowptr=self._rowptr,
colcount=self._colcount, csr2csc=self._csr2csc, col=self._col,
csc2csr=self._csc2csr, is_sorted=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def sparse_sizes(self) -> List[int]:
    """Return the sparse dimensions ``[num_rows, num_cols]`` of the matrix."""
    return self._sparse_sizes
...@@ -259,11 +264,18 @@ class SparseStorage(object): ...@@ -259,11 +264,18 @@ class SparseStorage(object):
if colcount is not None: if colcount is not None:
colcount = colcount[:-diff_1] colcount = colcount[:-diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col, return SparseStorage(
value=self._value, sparse_sizes=sparse_sizes, row=self._row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=self._csr2csc, col=self._col,
csc2csr=self._csc2csr, is_sorted=True) value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def has_rowcount(self) -> bool:
    """Return ``True`` if the per-row nonzero counts are already cached."""
    return self._rowcount is not None
...@@ -308,8 +320,10 @@ class SparseStorage(object): ...@@ -308,8 +320,10 @@ class SparseStorage(object):
if colptr is not None: if colptr is not None:
colcount = colptr[1:] - colptr[:-1] colcount = colptr[1:] - colptr[:-1]
else: else:
colcount = scatter_add(torch.ones_like(self._col), self._col, colcount = scatter_add(
dim_size=self._sparse_sizes[1]) torch.ones_like(self._col),
self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount self._colcount = colcount
return colcount return colcount
...@@ -361,10 +375,18 @@ class SparseStorage(object): ...@@ -361,10 +375,18 @@ class SparseStorage(object):
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, rowcount=None, row=row,
colptr=None, colcount=None, csr2csc=None, rowptr=None,
csc2csr=None, is_sorted=True) col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
def fill_cache_(self): def fill_cache_(self):
self.row() self.row()
...@@ -399,12 +421,18 @@ class SparseStorage(object): ...@@ -399,12 +421,18 @@ class SparseStorage(object):
return count return count
def copy(self):
    """Return a new ``SparseStorage`` sharing all cached tensors with
    ``self`` (no tensor data is cloned; see ``clone`` for a deep copy)."""
    return SparseStorage(
        row=self._row,
        rowptr=self._rowptr,
        col=self._col,
        value=self._value,
        sparse_sizes=self._sparse_sizes,
        rowcount=self._rowcount,
        colptr=self._colptr,
        colcount=self._colcount,
        csr2csc=self._csr2csc,
        csc2csr=self._csc2csr,
        # inputs come straight from an existing storage, so they are
        # already sorted — skip re-sorting on construction
        is_sorted=True)
def clone(self): def clone(self):
row = self._row row = self._row
...@@ -432,11 +460,18 @@ class SparseStorage(object): ...@@ -432,11 +460,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def type_as(self, tensor=torch.Tensor): def type_as(self, tensor=torch.Tensor):
value = self._value value = self._value
...@@ -477,11 +512,18 @@ class SparseStorage(object): ...@@ -477,11 +512,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking) csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
...@@ -509,11 +551,18 @@ class SparseStorage(object): ...@@ -509,11 +551,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.pin_memory() csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(
sparse_sizes=self._sparse_sizes, row=row,
rowcount=rowcount, colptr=colptr, rowptr=rowptr,
colcount=colcount, csr2csc=csr2csc, col=col,
csc2csr=csc2csr, is_sorted=True) value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
is_pinned = True is_pinned = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment