Commit 19fd8251 authored by limm's avatar limm
Browse files

support v0.6.16

parent 9ccee9c0
...@@ -2,6 +2,7 @@ from typing import Optional ...@@ -2,6 +2,7 @@ from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
...@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor: ...@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor:
row, col, value = src.coo() row, col, value = src.coo()
if value is None: if value is None:
value = torch.ones(row.size(0)) value = torch.ones(row.size(0), device=row.device)
sizes = list(value.size()) sizes = list(value.size())
sizes[0] = min(src.size(0), src.size(1)) sizes[0] = min(src.size(0), src.size(1))
......
from typing import Tuple from typing import Optional, Tuple
import torch import torch
from torch import Tensor
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
...@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor, ...@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor,
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0) A = src.to_torch_sparse_coo_tensor()
rowptrA, colA, valueA = src.csr() B = other.to_torch_sparse_coo_tensor()
rowptrB, colB, valueB = other.csr() C = torch.sparse.mm(A, B)
value = valueA if valueA is not None else valueB edge_index = C._indices()
if valueA is not None and valueA.dtype == torch.half: row, col = edge_index[0], edge_index[1]
valueA = valueA.to(torch.float) value: Optional[Tensor] = None
if valueB is not None and valueB.dtype == torch.half: if src.has_value() and other.has_value():
valueB = valueB.to(torch.float) value = C._values()
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( return SparseTensor(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) row=row,
if valueC is not None and value is not None: col=col,
valueC = valueC.to(value.dtype) value=value,
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC, sparse_sizes=(C.size(0), C.size(1)),
sparse_sizes=(M, K), is_sorted=True) is_sorted=True,
trust_data=True,
)
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
......
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.
Tensor, List[int], List[int]]:
return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
src.storage.col(),
src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
fill_value: float = 0.) -> torch.Tensor:
fill_value = torch.tensor(fill_value, dtype=src.dtype)
return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)
SparseTensor.padded_index = padded_index
from typing import Any
import torch import torch
import torch_scatter
from packaging import version
reductions = ['sum', 'add', 'mean', 'min', 'max'] reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long] dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double] grad_dtypes = [torch.half, torch.float, torch.double]
if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
dtypes.append(torch.bfloat16)
grad_dtypes.append(torch.bfloat16)
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] devices += [torch.device('cuda:0')]
def tensor(x, dtype, device): def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device) return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment