Commit 19fd8251 authored by limm's avatar limm
Browse files

support v0.6.16

parent 9ccee9c0
......@@ -2,6 +2,7 @@ from typing import Optional
import torch
from torch import Tensor
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
......@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor:
row, col, value = src.coo()
if value is None:
value = torch.ones(row.size(0))
value = torch.ones(row.size(0), device=row.device)
sizes = list(value.size())
sizes[0] = min(src.size(0), src.size(1))
......
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch_sparse.tensor import SparseTensor
......@@ -90,21 +91,23 @@ def spmm(src: SparseTensor, other: torch.Tensor,
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr()
rowptrB, colB, valueB = other.csr()
value = valueA if valueA is not None else valueB
if valueA is not None and valueA.dtype == torch.half:
valueA = valueA.to(torch.float)
if valueB is not None and valueB.dtype == torch.half:
valueB = valueB.to(torch.float)
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
if valueC is not None and value is not None:
valueC = valueC.to(value.dtype)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=(M, K), is_sorted=True)
A = src.to_torch_sparse_coo_tensor()
B = other.to_torch_sparse_coo_tensor()
C = torch.sparse.mm(A, B)
edge_index = C._indices()
row, col = edge_index[0], edge_index[1]
value: Optional[Tensor] = None
if src.has_value() and other.has_value():
value = C._values()
return SparseTensor(
row=row,
col=col,
value=value,
sparse_sizes=(C.size(0), C.size(1)),
is_sorted=True,
trust_data=True,
)
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
......
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.
Tensor, List[int], List[int]]:
return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
src.storage.col(),
src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
fill_value: float = 0.) -> torch.Tensor:
fill_value = torch.tensor(fill_value, dtype=src.dtype)
return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)
SparseTensor.padded_index = padded_index
from typing import Any
import torch
import torch_scatter
from packaging import version
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]
if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
dtypes.append(torch.bfloat16)
grad_dtypes.append(torch.bfloat16)
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment