"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fe5f035f797a5fa663a98030c9d0ec2f982cd09d"
Commit b0ff709e authored by rusty1s's avatar rusty1s
Browse files

torch_csr_tensor

parent fcf15650
...@@ -2,12 +2,11 @@ from itertools import product ...@@ -2,12 +2,11 @@ from itertools import product
import pytest import pytest
import torch import torch
import torch_scatter
from torch_sparse.matmul import matmul from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
import torch_scatter
from .utils import reductions, devices, grad_dtypes from .utils import devices, grad_dtypes, reductions
@pytest.mark.parametrize('dtype,device,reduce', @pytest.mark.parametrize('dtype,device,reduce',
......
from typing import Tuple from typing import Tuple
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
......
from textwrap import indent from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
import torch
from torch_scatter import segment_csr from torch_scatter import segment_csr
from torch_sparse.storage import SparseStorage, get_layout from torch_sparse.storage import SparseStorage, get_layout
...@@ -13,14 +13,16 @@ from torch_sparse.storage import SparseStorage, get_layout ...@@ -13,14 +13,16 @@ from torch_sparse.storage import SparseStorage, get_layout
class SparseTensor(object): class SparseTensor(object):
storage: SparseStorage storage: SparseStorage
def __init__(self, row: Optional[torch.Tensor] = None, def __init__(
rowptr: Optional[torch.Tensor] = None, self,
col: Optional[torch.Tensor] = None, row: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], col: Optional[torch.Tensor] = None,
Optional[int]]] = None, value: Optional[torch.Tensor] = None,
is_sorted: bool = False, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
trust_data: bool = False): is_sorted: bool = False,
trust_data: bool = False,
):
self.storage = SparseStorage( self.storage = SparseStorage(
row=row, row=row,
rowptr=rowptr, rowptr=rowptr,
...@@ -33,7 +35,8 @@ class SparseTensor(object): ...@@ -33,7 +35,8 @@ class SparseTensor(object):
csr2csc=None, csr2csc=None,
csc2csr=None, csc2csr=None,
is_sorted=is_sorted, is_sorted=is_sorted,
trust_data=trust_data) trust_data=trust_data,
)
@classmethod @classmethod
def from_storage(self, storage: SparseStorage): def from_storage(self, storage: SparseStorage):
...@@ -44,7 +47,8 @@ class SparseTensor(object): ...@@ -44,7 +47,8 @@ class SparseTensor(object):
value=storage._value, value=storage._value,
sparse_sizes=storage._sparse_sizes, sparse_sizes=storage._sparse_sizes,
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
out.storage._rowcount = storage._rowcount out.storage._rowcount = storage._rowcount
out.storage._colptr = storage._colptr out.storage._colptr = storage._colptr
out.storage._colcount = storage._colcount out.storage._colcount = storage._colcount
...@@ -53,12 +57,14 @@ class SparseTensor(object): ...@@ -53,12 +57,14 @@ class SparseTensor(object):
return out return out
@classmethod @classmethod
def from_edge_index(self, edge_index: torch.Tensor, def from_edge_index(
edge_attr: Optional[torch.Tensor] = None, self,
sparse_sizes: Optional[Tuple[Optional[int], edge_index: torch.Tensor,
Optional[int]]] = None, edge_attr: Optional[torch.Tensor] = None,
is_sorted: bool = False, sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
trust_data: bool = False): is_sorted: bool = False,
trust_data: bool = False,
):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1], return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes, value=edge_attr, sparse_sizes=sparse_sizes,
is_sorted=is_sorted, trust_data=trust_data) is_sorted=is_sorted, trust_data=trust_data)
...@@ -97,6 +103,20 @@ class SparseTensor(object): ...@@ -97,6 +103,20 @@ class SparseTensor(object):
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True) is_sorted=True, trust_data=True)
@classmethod
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
has_value: bool = True):
rowptr = mat.crow_indices()
col = mat.col_indices()
value: Optional[torch.Tensor] = None
if has_value:
value = mat.values()
return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod @classmethod
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True, def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
dtype: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[int] = None, device: Optional[torch.device] = None,
...@@ -140,7 +160,8 @@ class SparseTensor(object): ...@@ -140,7 +160,8 @@ class SparseTensor(object):
value=value, value=value,
sparse_sizes=(M, N), sparse_sizes=(M, N),
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
out.storage._rowcount = rowcount out.storage._rowcount = rowcount
out.storage._colptr = colptr out.storage._colptr = colptr
out.storage._colcount = colcount out.storage._colcount = colcount
...@@ -158,8 +179,8 @@ class SparseTensor(object): ...@@ -158,8 +179,8 @@ class SparseTensor(object):
value = self.storage.value() value = self.storage.value()
if value is None or dtype == value.dtype: if value is None or dtype == value.dtype:
return self return self
return self.from_storage(self.storage.type( return self.from_storage(
dtype=dtype, non_blocking=non_blocking)) self.storage.type(dtype=dtype, non_blocking=non_blocking))
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False): def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.type(dtype=tensor.dtype, non_blocking=non_blocking) return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
...@@ -167,8 +188,8 @@ class SparseTensor(object): ...@@ -167,8 +188,8 @@ class SparseTensor(object):
def to_device(self, device: torch.device, non_blocking: bool = False): def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self.device(): if device == self.device():
return self return self
return self.from_storage(self.storage.to_device( return self.from_storage(
device=device, non_blocking=non_blocking)) self.storage.to_device(device=device, non_blocking=non_blocking))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking) return self.to_device(device=tensor.device, non_blocking=non_blocking)
...@@ -362,7 +383,8 @@ class SparseTensor(object): ...@@ -362,7 +383,8 @@ class SparseTensor(object):
value=value, value=value,
sparse_sizes=(N, N), sparse_sizes=(N, N),
is_sorted=True, is_sorted=True,
trust_data=True) trust_data=True,
)
return out return out
def detach_(self): def detach_(self):
...@@ -479,6 +501,15 @@ class SparseTensor(object): ...@@ -479,6 +501,15 @@ class SparseTensor(object):
return torch.sparse_coo_tensor(index, value, self.sizes()) return torch.sparse_coo_tensor(index, value, self.sizes())
def to_torch_sparse_csr_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
rowptr, col, value = self.csr()
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
# Python Bindings ############################################################# # Python Bindings #############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment