Commit 15ff09d5 authored by rusty1s's avatar rusty1s
Browse files

removed load library calls

parent 60a29466
# flake8: noqa
import importlib
import os.path as osp
......@@ -9,8 +7,9 @@ __version__ = '0.5.1'
expected_torch_version = (1, 4)
try:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
'_version', [osp.dirname(__file__)]).origin)
for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except OSError as e:
if 'undefined symbol' in str(e):
major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
......@@ -40,26 +39,27 @@ if torch.version.cuda is not None: # pragma: no cover
f'{major}.{minor}. Please reinstall the torch_sparse that '
f'matches your PyTorch install.')
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
from .narrow import narrow, __narrow_diag__
from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import remove_diag, set_diag, fill_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import matmul
from .cat import cat, cat_diag
from .storage import SparseStorage # noqa: E4402
from .tensor import SparseTensor # noqa: E4402
from .transpose import t # noqa: E4402
from .narrow import narrow, __narrow_diag__ # noqa: E4402
from .select import select # noqa: E4402
from .index_select import index_select, index_select_nnz # noqa: E4402
from .masked_select import masked_select, masked_select_nnz # noqa: E4402
from .diag import remove_diag, set_diag, fill_diag # noqa: E4402
from .add import add, add_, add_nnz, add_nnz_ # noqa: E4402
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa: E4402
from .reduce import sum, mean, min, max # noqa: E4402
from .matmul import matmul # noqa: E4402
from .cat import cat, cat_diag # noqa: E4402
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce
from .transpose import transpose
from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
from .convert import to_torch_sparse, from_torch_sparse # noqa: E4402
from .convert import to_scipy, from_scipy # noqa: E4402
from .coalesce import coalesce # noqa: E4402
from .transpose import transpose # noqa: E4402
from .eye import eye # noqa: E4402
from .spmm import spmm # noqa: E4402
from .spspmm import spspmm # noqa: E4402
__all__ = [
'SparseStorage',
......
from typing import List, Optional
from typing import List
import torch
from torch_sparse.storage import SparseStorage
......@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return tensors[0].from_storage(storage)
elif dim == 1:
......@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=False)
storage = SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=False)
return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim():
......@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return tensors[0].from_storage(storage)
import importlib
import os.path as osp
from typing import Optional
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
# Register the compiled '_diag' C++ extension's custom ops with torch.ops.
# PathFinder locates the platform-specific shared library (e.g. _diag.so)
# that sits next to this file; .origin is its absolute path.
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
    '_diag', [osp.dirname(__file__)]).origin)
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
......@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone()
colcount[col[mask]] -= 1
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage)
@torch.jit.script
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k)
row, col, value = src.coo()
......@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
if values is not None:
new_value[inv_mask] = values
else:
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
new_value[inv_mask] = torch.ones((num_diag, ),
dtype=value.dtype,
device=value.device)
rowcount = src.storage._rowcount
......@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
value=new_value, sparse_sizes=src.sparse_sizes(),
rowcount=rowcount, colptr=None, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=True)
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage)
......
import importlib
import os.path as osp
from typing import Union, Tuple
import torch
from torch_sparse.tensor import SparseTensor
# Register the compiled '_spmm' and '_spspmm' extensions with torch.ops so
# their custom kernels (torch.ops.torch_sparse.*) are callable below.
# Looping over the names avoids duplicating the PathFinder boilerplate and
# matches the loader loop used in the package __init__.
for library in ['_spmm', '_spspmm']:
    torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
        library, [osp.dirname(__file__)]).origin)
@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
......@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=torch.Size([M, K]), is_sorted=True)
return SparseTensor(
row=None,
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
@torch.jit.script
......@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError
def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
def matmul(src: SparseTensor,
other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"):
if torch.is_tensor(other):
return spmm(src, other, reduce)
......
import warnings
import importlib
import os.path as osp
from typing import Optional, List
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final
# Register the compiled '_convert' C++ extension's custom ops with torch.ops.
# PathFinder resolves the platform-specific shared library located in the
# same directory as this module; .origin is its absolute file path.
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
    '_convert', [osp.dirname(__file__)]).origin)
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
......@@ -35,7 +30,8 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None,
def __init__(self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
......@@ -196,7 +192,8 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self, value: Optional[torch.Tensor],
def set_value_(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
......@@ -208,7 +205,8 @@ class SparseStorage(object):
self._value = value
return self
def set_value(self, value: Optional[torch.Tensor],
def set_value(self,
value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
......@@ -217,11 +215,18 @@ class SparseStorage(object):
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def sparse_sizes(self) -> List[int]:
    """Return the cached sparse sizes of this storage.

    Presumably ``[num_rows, num_cols]`` of the sparse matrix — confirm
    against the constructor's ``sparse_sizes`` argument.
    """
    sizes = self._sparse_sizes
    return sizes
......@@ -259,11 +264,18 @@ class SparseStorage(object):
if colcount is not None:
colcount = colcount[:-diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
return SparseStorage(
row=self._row,
rowptr=rowptr,
col=self._col,
value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def has_rowcount(self) -> bool:
    """Return ``True`` if the ``_rowcount`` cache is already populated."""
    cached = self._rowcount
    return cached is not None
......@@ -308,8 +320,10 @@ class SparseStorage(object):
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
colcount = scatter_add(
torch.ones_like(self._col),
self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
......@@ -361,10 +375,18 @@ class SparseStorage(object):
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
def fill_cache_(self):
self.row()
......@@ -399,12 +421,18 @@ class SparseStorage(object):
return count
def copy(self):
    """Return a shallow copy of this storage.

    Builds a new ``SparseStorage`` from this storage's fields; the
    underlying tensors are shared, not cloned (contrast with ``clone``).
    ``is_sorted=True`` skips re-sorting, since the fields are copied
    verbatim from an already-constructed storage.
    """
    # NOTE(review): the original block contained two back-to-back return
    # statements (an unreachable duplicate, apparently a merge/diff
    # artifact); only the single intended return is kept here.
    return SparseStorage(
        row=self._row,
        rowptr=self._rowptr,
        col=self._col,
        value=self._value,
        sparse_sizes=self._sparse_sizes,
        rowcount=self._rowcount,
        colptr=self._colptr,
        colcount=self._colcount,
        csr2csc=self._csr2csc,
        csc2csr=self._csc2csr,
        is_sorted=True)
def clone(self):
row = self._row
......@@ -432,11 +460,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def type_as(self, tensor=torch.Tensor):
value = self._value
......@@ -477,11 +512,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def pin_memory(self):
row = self._row
......@@ -509,11 +551,18 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def is_pinned(self) -> bool:
is_pinned = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment