Commit b0ff709e authored by rusty1s's avatar rusty1s
Browse files

torch_csr_tensor

parent fcf15650
......@@ -2,12 +2,11 @@ from itertools import product
import pytest
import torch
import torch_scatter
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
import torch_scatter
from .utils import reductions, devices, grad_dtypes
from .utils import devices, grad_dtypes, reductions
@pytest.mark.parametrize('dtype,device,reduce',
......
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
......
from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import numpy as np
import scipy.sparse
import torch
from torch_scatter import segment_csr
from torch_sparse.storage import SparseStorage, get_layout
......@@ -13,14 +13,16 @@ from torch_sparse.storage import SparseStorage, get_layout
class SparseTensor(object):
storage: SparseStorage
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False):
def __init__(
self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False,
):
self.storage = SparseStorage(
row=row,
rowptr=rowptr,
......@@ -33,7 +35,8 @@ class SparseTensor(object):
csr2csc=None,
csc2csr=None,
is_sorted=is_sorted,
trust_data=trust_data)
trust_data=trust_data,
)
@classmethod
def from_storage(self, storage: SparseStorage):
......@@ -44,7 +47,8 @@ class SparseTensor(object):
value=storage._value,
sparse_sizes=storage._sparse_sizes,
is_sorted=True,
trust_data=True)
trust_data=True,
)
out.storage._rowcount = storage._rowcount
out.storage._colptr = storage._colptr
out.storage._colcount = storage._colcount
......@@ -53,12 +57,14 @@ class SparseTensor(object):
return out
@classmethod
def from_edge_index(self, edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False):
def from_edge_index(
self,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False,
):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes,
is_sorted=is_sorted, trust_data=trust_data)
......@@ -97,6 +103,20 @@ class SparseTensor(object):
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
has_value: bool = True):
rowptr = mat.crow_indices()
col = mat.col_indices()
value: Optional[torch.Tensor] = None
if has_value:
value = mat.values()
return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
dtype: Optional[int] = None, device: Optional[torch.device] = None,
......@@ -140,7 +160,8 @@ class SparseTensor(object):
value=value,
sparse_sizes=(M, N),
is_sorted=True,
trust_data=True)
trust_data=True,
)
out.storage._rowcount = rowcount
out.storage._colptr = colptr
out.storage._colcount = colcount
......@@ -158,8 +179,8 @@ class SparseTensor(object):
value = self.storage.value()
if value is None or dtype == value.dtype:
return self
return self.from_storage(self.storage.type(
dtype=dtype, non_blocking=non_blocking))
return self.from_storage(
self.storage.type(dtype=dtype, non_blocking=non_blocking))
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
......@@ -167,8 +188,8 @@ class SparseTensor(object):
def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self.device():
return self
return self.from_storage(self.storage.to_device(
device=device, non_blocking=non_blocking))
return self.from_storage(
self.storage.to_device(device=device, non_blocking=non_blocking))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
......@@ -362,7 +383,8 @@ class SparseTensor(object):
value=value,
sparse_sizes=(N, N),
is_sorted=True,
trust_data=True)
trust_data=True,
)
return out
def detach_(self):
......@@ -479,6 +501,15 @@ class SparseTensor(object):
return torch.sparse_coo_tensor(index, value, self.sizes())
def to_torch_sparse_csr_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
rowptr, col, value = self.csr()
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
# Python Bindings #############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment