Unverified Commit fa74ea70 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] Polish diag_matrix.py and sparse_matrix.py. (#5178)



* polish diagonal

* polsih

* more

* more

* add paramaters-back
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 0868f179
...@@ -8,22 +8,15 @@ from .sparse_matrix import from_coo, SparseMatrix ...@@ -8,22 +8,15 @@ from .sparse_matrix import from_coo, SparseMatrix
class DiagMatrix: class DiagMatrix:
"""Diagonal Matrix Class r"""Class for diagonal matrix.
Parameters Parameters
---------- ----------
val : torch.Tensor val : torch.Tensor
Diagonal of the matrix. It can take shape (N) or (N, D). Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``
shape : tuple[int, int], optional shape : tuple[int, int], optional
If not specified, it will be inferred from :attr:`val`, i.e., If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,
(N, N). Otherwise, :attr:`len(val)` must be equal to :attr:`min(shape)`. otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``
Attributes
----------
val : torch.Tensor
Diagonal of the matrix.
shape : tuple[int, int]
Shape of the matrix.
""" """
def __init__( def __init__(
...@@ -40,71 +33,71 @@ class DiagMatrix: ...@@ -40,71 +33,71 @@ class DiagMatrix:
self._val = val self._val = val
self._shape = shape self._shape = shape
def __repr__(self):
return _diag_matrix_str(self)
@property @property
def val(self) -> torch.Tensor: def val(self) -> torch.Tensor:
"""Get the values of the nonzero elements. """Returns the values of the non-zero elements.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Values of the nonzero elements Values of the non-zero elements
""" """
return self._val return self._val
@property @property
def shape(self) -> Tuple[int]: def shape(self) -> Tuple[int]:
"""Shape of the sparse matrix. """Returns the shape of the diagonal matrix.
Returns Returns
------- -------
Tuple[int] Tuple[int]
The shape of the matrix The shape of the diagonal matrix
""" """
return self._shape return self._shape
def __repr__(self):
return _diag_matrix_str(self)
@property @property
def nnz(self) -> int: def nnz(self) -> int:
"""Return the number of non-zero values in the matrix """Returns the number of non-zero elements in the diagonal matrix.
Returns Returns
------- -------
int int
The number of non-zero values in the matrix The number of non-zero elements in the diagonal matrix
""" """
return self.val.shape[0] return self.val.shape[0]
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
"""Return the data type of the matrix """Returns the data type of the diagonal matrix.
Returns Returns
------- -------
torch.dtype torch.dtype
Data type of the matrix Data type of the diagonal matrix
""" """
return self.val.dtype return self.val.dtype
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
"""Return the device of the matrix """Returns the device the diagonal matrix is on.
Returns Returns
------- -------
torch.device torch.device
Device of the matrix The device the diagonal matrix is on
""" """
return self.val.device return self.val.device
def to_sparse(self) -> SparseMatrix: def to_sparse(self) -> SparseMatrix:
"""Convert the diagonal matrix into a sparse matrix object """Returns a copy in sparse matrix format of the diagonal matrix.
Returns Returns
------- -------
SparseMatrix SparseMatrix
The converted sparse matrix object The copy in sparse matrix format
Example Example
------- -------
...@@ -123,12 +116,12 @@ class DiagMatrix: ...@@ -123,12 +116,12 @@ class DiagMatrix:
return from_coo(row=row, col=col, val=self.val, shape=self.shape) return from_coo(row=row, col=col, val=self.val, shape=self.shape)
def to_dense(self) -> torch.Tensor: def to_dense(self) -> torch.Tensor:
"""Return a dense representation of the matrix. """Returns a copy in dense matrix format of the diagonal matrix.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Dense representation of the diagonal matrix. The copy in dense matrix format
""" """
val = self.val val = self.val
device = self.device device = self.device
...@@ -148,12 +141,12 @@ class DiagMatrix: ...@@ -148,12 +141,12 @@ class DiagMatrix:
return self.transpose() return self.transpose()
def transpose(self): def transpose(self):
"""Return the transpose of the matrix. """Returns a matrix that is a transposed version of the diagonal matrix.
Returns Returns
------- -------
DiagMatrix DiagMatrix
The transpose of the matrix. The transpose of the matrix
Example Example
-------- --------
...@@ -168,22 +161,22 @@ class DiagMatrix: ...@@ -168,22 +161,22 @@ class DiagMatrix:
return DiagMatrix(self.val, self.shape[::-1]) return DiagMatrix(self.val, self.shape[::-1])
def to(self, device=None, dtype=None): def to(self, device=None, dtype=None):
"""Perform matrix dtype and/or device conversion. If the target device """Performs matrix dtype and/or device conversion. If the target device
and dtype are already in use, the original matrix will be returned. and dtype are already in use, the original matrix will be returned.
Parameters Parameters
---------- ----------
device : torch.device, optional device : torch.device, optional
The target device of the matrix if given, otherwise the current The target device of the matrix if provided, otherwise the current
device will be used device will be used
dtype : torch.dtype, optional dtype : torch.dtype, optional
The target data type of the matrix values if given, otherwise the The target data type of the matrix values if provided, otherwise the
current data type will be used current data type will be used
Returns Returns
------- -------
DiagMatrix DiagMatrix
The result matrix The converted matrix
Example Example
-------- --------
...@@ -205,7 +198,7 @@ class DiagMatrix: ...@@ -205,7 +198,7 @@ class DiagMatrix:
return diag(self.val.to(device=device, dtype=dtype), self.shape) return diag(self.val.to(device=device, dtype=dtype), self.shape)
def cuda(self): def cuda(self):
"""Move the matrix to GPU. If the matrix is already on GPU, the """Moves the matrix to GPU. If the matrix is already on GPU, the
original matrix will be returned. If multiple GPU devices exist, original matrix will be returned. If multiple GPU devices exist,
'cuda:0' will be selected. 'cuda:0' will be selected.
...@@ -226,7 +219,7 @@ class DiagMatrix: ...@@ -226,7 +219,7 @@ class DiagMatrix:
return self.to(device="cuda") return self.to(device="cuda")
def cpu(self): def cpu(self):
"""Move the matrix to CPU. If the matrix is already on CPU, the """Moves the matrix to CPU. If the matrix is already on CPU, the
original matrix will be returned. original matrix will be returned.
Returns Returns
...@@ -246,7 +239,7 @@ class DiagMatrix: ...@@ -246,7 +239,7 @@ class DiagMatrix:
return self.to(device="cpu") return self.to(device="cpu")
def float(self): def float(self):
"""Convert the matrix values to float data type. If the matrix already """Converts the matrix values to float data type. If the matrix already
uses float data type, the original matrix will be returned. uses float data type, the original matrix will be returned.
Returns Returns
...@@ -266,7 +259,7 @@ class DiagMatrix: ...@@ -266,7 +259,7 @@ class DiagMatrix:
return self.to(dtype=torch.float) return self.to(dtype=torch.float)
def double(self): def double(self):
"""Convert the matrix values to double data type. If the matrix already """Converts the matrix values to double data type. If the matrix already
uses double data type, the original matrix will be returned. uses double data type, the original matrix will be returned.
Returns Returns
...@@ -286,7 +279,7 @@ class DiagMatrix: ...@@ -286,7 +279,7 @@ class DiagMatrix:
return self.to(dtype=torch.double) return self.to(dtype=torch.double)
def int(self): def int(self):
"""Convert the matrix values to int data type. If the matrix already """Converts the matrix values to int data type. If the matrix already
uses int data type, the original matrix will be returned. uses int data type, the original matrix will be returned.
Returns Returns
...@@ -306,7 +299,7 @@ class DiagMatrix: ...@@ -306,7 +299,7 @@ class DiagMatrix:
return self.to(dtype=torch.int) return self.to(dtype=torch.int)
def long(self): def long(self):
"""Convert the matrix values to long data type. If the matrix already """Converts the matrix values to long data type. If the matrix already
uses long data type, the original matrix will be returned. uses long data type, the original matrix will be returned.
Returns Returns
...@@ -329,15 +322,15 @@ class DiagMatrix: ...@@ -329,15 +322,15 @@ class DiagMatrix:
def diag( def diag(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
) -> DiagMatrix: ) -> DiagMatrix:
"""Create a diagonal matrix based on the diagonal values """Creates a diagonal matrix based on the diagonal values.
Parameters Parameters
---------- ----------
val : torch.Tensor val : torch.Tensor
Diagonal of the matrix. It can take shape (N) or (N, D). Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``
shape : tuple[int, int], optional shape : tuple[int, int], optional
If not specified, it will be inferred from :attr:`val`, i.e., If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,
(N, N). Otherwise, :attr:`len(val)` must be equal to :attr:`min(shape)`. otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``
Returns Returns
------- -------
...@@ -383,7 +376,7 @@ def identity( ...@@ -383,7 +376,7 @@ def identity(
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> DiagMatrix: ) -> DiagMatrix:
"""Create a diagonal matrix with ones on the diagonal and zeros elsewhere """Creates a diagonal matrix with ones on the diagonal and zeros elsewhere.
Parameters Parameters
---------- ----------
...@@ -447,7 +440,8 @@ def identity( ...@@ -447,7 +440,8 @@ def identity(
def _diag_matrix_str(spmat: DiagMatrix) -> str: def _diag_matrix_str(spmat: DiagMatrix) -> str:
"""Internal function for converting a diagonal matrix to string """Internal function for converting a diagonal matrix to string
representation.""" representation.
"""
values_str = str(spmat.val) values_str = str(spmat.val)
meta_str = f"size={spmat.shape}" meta_str = f"size={spmat.shape}"
if spmat.val.dim() > 1: if spmat.val.dim() > 1:
......
...@@ -11,128 +11,127 @@ class SparseMatrix: ...@@ -11,128 +11,127 @@ class SparseMatrix:
def __init__(self, c_sparse_matrix: torch.ScriptObject): def __init__(self, c_sparse_matrix: torch.ScriptObject):
self.c_sparse_matrix = c_sparse_matrix self.c_sparse_matrix = c_sparse_matrix
def __repr__(self):
return _sparse_matrix_str(self)
@property @property
def val(self) -> torch.Tensor: def val(self) -> torch.Tensor:
"""Get the values of the nonzero elements. """Returns the values of the non-zero elements.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Values of the nonzero elements Values of the non-zero elements
""" """
return self.c_sparse_matrix.val() return self.c_sparse_matrix.val()
@property @property
def shape(self) -> Tuple[int]: def shape(self) -> Tuple[int]:
"""Shape of the sparse matrix. """Returns the shape of the sparse matrix.
Returns Returns
------- -------
Tuple[int] Tuple[int]
The shape of the matrix The shape of the sparse matrix
""" """
return tuple(self.c_sparse_matrix.shape()) return tuple(self.c_sparse_matrix.shape())
@property @property
def nnz(self) -> int: def nnz(self) -> int:
"""The number of nonzero elements of the sparse matrix. """Returns the number of non-zero elements in the sparse matrix.
Returns Returns
------- -------
int int
The number of nonzero elements of the matrix The number of non-zero elements of the matrix
""" """
return self.c_sparse_matrix.nnz() return self.c_sparse_matrix.nnz()
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
"""Data type of the values of the sparse matrix. """Returns the data type of the sparse matrix.
Returns Returns
------- -------
torch.dtype torch.dtype
Data type of the values of the matrix Data type of the sparse matrix
""" """
# FIXME: find a proper way to pass dtype from C++ to Python
return self.c_sparse_matrix.val().dtype return self.c_sparse_matrix.val().dtype
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
"""Device of the sparse matrix. """Returns the device the sparse matrix is on.
Returns Returns
------- -------
torch.device torch.device
Device of the matrix The device the sparse matrix is on
""" """
return self.c_sparse_matrix.device() return self.c_sparse_matrix.device()
@property @property
def row(self) -> torch.Tensor: def row(self) -> torch.Tensor:
"""Get the row indices of the nonzero elements. """Returns the row indices of the non-zero elements.
Returns Returns
------- -------
tensor tensor
Row indices of the nonzero elements Row indices of the non-zero elements
""" """
return self.coo()[0] return self.coo()[0]
@property @property
def col(self) -> torch.Tensor: def col(self) -> torch.Tensor:
"""Get the column indices of the nonzero elements. """Returns the column indices of the non-zero elements.
Returns Returns
------- -------
tensor tensor
Column indices of the nonzero elements Column indices of the non-zero elements
""" """
return self.coo()[1] return self.coo()[1]
def __repr__(self):
return _sparse_matrix_str(self)
def coo(self) -> Tuple[torch.Tensor, ...]: def coo(self) -> Tuple[torch.Tensor, ...]:
"""Get the coordinate (COO) representation of the sparse matrix. """Returns the coordinate (COO) representation of the sparse matrix.
Returns Returns
------- -------
Tuple[torch.Tensor, torch.Tensor] Tuple[torch.Tensor, torch.Tensor]
A tuple of tensors containing row and column coordinates. A tuple of tensors containing row and column coordinates
""" """
return self.c_sparse_matrix.coo() return self.c_sparse_matrix.coo()
def csr(self) -> Tuple[torch.Tensor, ...]: def csr(self) -> Tuple[torch.Tensor, ...]:
r"""Get the compressed sparse row (CSR) representation of the sparse r"""Returns the compressed sparse row (CSR) representation of the sparse
matrix. matrix.
Returns Returns
------- -------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of tensors containing row, column coordinates and value A tuple of tensors containing row, column coordinates and value
indices. indices
""" """
return self.c_sparse_matrix.csr() return self.c_sparse_matrix.csr()
def csc(self) -> Tuple[torch.Tensor, ...]: def csc(self) -> Tuple[torch.Tensor, ...]:
r"""Get the compressed sparse column (CSC) representation of the sparse r"""Returns the compressed sparse column (CSC) representation of the
matrix. sparse matrix.
Returns Returns
------- -------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of tensors containing row, column coordinates and value A tuple of tensors containing row, column coordinates and value
indices. indices
""" """
return self.c_sparse_matrix.csc() return self.c_sparse_matrix.csc()
def to_dense(self) -> torch.Tensor: def to_dense(self) -> torch.Tensor:
"""Return a dense representation of the matrix. """Returns a copy in dense matrix format of the sparse matrix.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Dense representation of the sparse matrix. The copy in dense matrix format
""" """
row, col = self.coo() row, col = self.coo()
val = self.val val = self.val
...@@ -175,7 +174,7 @@ class SparseMatrix: ...@@ -175,7 +174,7 @@ class SparseMatrix:
return SparseMatrix(self.c_sparse_matrix.transpose()) return SparseMatrix(self.c_sparse_matrix.transpose())
def to(self, device=None, dtype=None): def to(self, device=None, dtype=None):
"""Perform matrix dtype and/or device conversion. If the target device """Performs matrix dtype and/or device conversion. If the target device
and dtype are already in use, the original matrix will be returned. and dtype are already in use, the original matrix will be returned.
Parameters Parameters
...@@ -184,13 +183,13 @@ class SparseMatrix: ...@@ -184,13 +183,13 @@ class SparseMatrix:
The target device of the matrix if provided, otherwise the current The target device of the matrix if provided, otherwise the current
device will be used device will be used
dtype : torch.dtype, optional dtype : torch.dtype, optional
The target data type of the matrix values, otherwise the current The target data type of the matrix values if provided, otherwise the
data type will be used current data type will be used
Returns Returns
------- -------
SparseMatrix SparseMatrix
The result matrix The converted matrix
Example Example
-------- --------
...@@ -224,7 +223,7 @@ class SparseMatrix: ...@@ -224,7 +223,7 @@ class SparseMatrix:
return from_coo(row, col, val, self.shape) return from_coo(row, col, val, self.shape)
def cuda(self): def cuda(self):
"""Move the matrix to GPU. If the matrix is already on GPU, the """Moves the matrix to GPU. If the matrix is already on GPU, the
original matrix will be returned. If multiple GPU devices exist, original matrix will be returned. If multiple GPU devices exist,
'cuda:0' will be selected. 'cuda:0' will be selected.
...@@ -248,7 +247,7 @@ class SparseMatrix: ...@@ -248,7 +247,7 @@ class SparseMatrix:
return self.to(device="cuda") return self.to(device="cuda")
def cpu(self): def cpu(self):
"""Move the matrix to CPU. If the matrix is already on CPU, the """Moves the matrix to CPU. If the matrix is already on CPU, the
original matrix will be returned. original matrix will be returned.
Returns Returns
...@@ -271,7 +270,7 @@ class SparseMatrix: ...@@ -271,7 +270,7 @@ class SparseMatrix:
return self.to(device="cpu") return self.to(device="cpu")
def float(self): def float(self):
"""Convert the matrix values to float data type. If the matrix already """Converts the matrix values to float data type. If the matrix already
uses float data type, the original matrix will be returned. uses float data type, the original matrix will be returned.
Returns Returns
...@@ -295,7 +294,7 @@ class SparseMatrix: ...@@ -295,7 +294,7 @@ class SparseMatrix:
return self.to(dtype=torch.float) return self.to(dtype=torch.float)
def double(self): def double(self):
"""Convert the matrix values to double data type. If the matrix already """Converts the matrix values to double data type. If the matrix already
uses double data type, the original matrix will be returned. uses double data type, the original matrix will be returned.
Returns Returns
...@@ -318,7 +317,7 @@ class SparseMatrix: ...@@ -318,7 +317,7 @@ class SparseMatrix:
return self.to(dtype=torch.double) return self.to(dtype=torch.double)
def int(self): def int(self):
"""Convert the matrix values to int data type. If the matrix already """Converts the matrix values to int data type. If the matrix already
uses int data type, the original matrix will be returned. uses int data type, the original matrix will be returned.
Returns Returns
...@@ -341,7 +340,7 @@ class SparseMatrix: ...@@ -341,7 +340,7 @@ class SparseMatrix:
return self.to(dtype=torch.int) return self.to(dtype=torch.int)
def long(self): def long(self):
"""Convert the matrix values to long data type. If the matrix already """Converts the matrix values to long data type. If the matrix already
uses long data type, the original matrix will be returned. uses long data type, the original matrix will be returned.
Returns Returns
...@@ -364,14 +363,14 @@ class SparseMatrix: ...@@ -364,14 +363,14 @@ class SparseMatrix:
return self.to(dtype=torch.long) return self.to(dtype=torch.long)
def coalesce(self): def coalesce(self):
"""Return a coalesced sparse matrix. """Returns a coalesced sparse matrix.
A coalesced sparse matrix satisfies the following properties: A coalesced sparse matrix satisfies the following properties:
- the indices of the non-zero elements are unique, - the indices of the non-zero elements are unique,
- the indices are sorted in lexicographical order. - the indices are sorted in lexicographical order.
The coalescing process will accumulate the non-zero values of the same The coalescing process will accumulate the non-zero elements of the same
indices by summation. indices by summation.
The function does not support autograd. The function does not support autograd.
...@@ -379,7 +378,7 @@ class SparseMatrix: ...@@ -379,7 +378,7 @@ class SparseMatrix:
Returns Returns
------- -------
SparseMatrix SparseMatrix
The coalesced sparse matrix. The coalesced sparse matrix
Examples Examples
-------- --------
...@@ -397,12 +396,7 @@ class SparseMatrix: ...@@ -397,12 +396,7 @@ class SparseMatrix:
return SparseMatrix(self.c_sparse_matrix.coalesce()) return SparseMatrix(self.c_sparse_matrix.coalesce())
def has_duplicate(self): def has_duplicate(self):
"""Return whether this sparse matrix contains duplicate indices. """Returns ``True`` if the sparse matrix contains duplicate indices.
Returns
-------
bool
True if this sparse matrix contains duplicate indices.
Examples Examples
-------- --------
...@@ -424,14 +418,14 @@ def from_coo( ...@@ -424,14 +418,14 @@ def from_coo(
val: Optional[torch.Tensor] = None, val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None, shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix: ) -> SparseMatrix:
"""Create a sparse matrix from row and column coordinates. """Creates a sparse matrix from row and column coordinates.
Parameters Parameters
---------- ----------
row : tensor row : tensor
The row indices of shape (nnz). The row indices of shape (nnz)
col : tensor col : tensor
The column indices of shape (nnz). The column indices of shape (nnz)
val : tensor, optional val : tensor, optional
The values of shape (nnz) or (nnz, D). If None, it will be a tensor of The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
shape (nnz) filled by 1. shape (nnz) filled by 1.
...@@ -492,11 +486,11 @@ def from_csr( ...@@ -492,11 +486,11 @@ def from_csr(
val: Optional[torch.Tensor] = None, val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None, shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix: ) -> SparseMatrix:
"""Create a sparse matrix from CSR indices. """Creates a sparse matrix from CSR indices.
For row i of the sparse matrix For row i of the sparse matrix
- the column indices of the nonzero entries are stored in - the column indices of the non-zero elements are stored in
``indices[indptr[i]: indptr[i+1]]`` ``indices[indptr[i]: indptr[i+1]]``
- the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]`` - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``
...@@ -504,9 +498,9 @@ def from_csr( ...@@ -504,9 +498,9 @@ def from_csr(
---------- ----------
indptr : tensor indptr : tensor
Pointer to the column indices of shape (N + 1), where N is the number Pointer to the column indices of shape (N + 1), where N is the number
of rows. of rows
indices : tensor indices : tensor
The column indices of shape (nnz). The column indices of shape (nnz)
val : tensor, optional val : tensor, optional
The values of shape (nnz) or (nnz, D). If None, it will be a tensor of The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
shape (nnz) filled by 1. shape (nnz) filled by 1.
...@@ -576,11 +570,11 @@ def from_csc( ...@@ -576,11 +570,11 @@ def from_csc(
val: Optional[torch.Tensor] = None, val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None, shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix: ) -> SparseMatrix:
"""Create a sparse matrix from CSC indices. """Creates a sparse matrix from CSC indices.
For column i of the sparse matrix For column i of the sparse matrix
- the row indices of the nonzero entries are stored in - the row indices of the non-zero elements are stored in
``indices[indptr[i]: indptr[i+1]]`` ``indices[indptr[i]: indptr[i+1]]``
- the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]`` - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``
...@@ -588,9 +582,9 @@ def from_csc( ...@@ -588,9 +582,9 @@ def from_csc(
---------- ----------
indptr : tensor indptr : tensor
Pointer to the row indices of shape N + 1, where N is the Pointer to the row indices of shape N + 1, where N is the
number of columns. number of columns
indices : tensor indices : tensor
The row indices of shape nnz. The row indices of shape nnz
val : tensor, optional val : tensor, optional
The values of shape (nnz) or (nnz, D). If None, it will be a tensor of The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
shape (nnz) filled by 1. shape (nnz) filled by 1.
...@@ -655,17 +649,17 @@ def from_csc( ...@@ -655,17 +649,17 @@ def from_csc(
def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix: def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
"""Create a sparse matrix from an existing sparse matrix using new values. """Creates a sparse matrix from an existing sparse matrix using new values.
The new sparse matrix will have the same nonzero indices as the given The new sparse matrix will have the same non-zero indices as the given
sparse matrix and use the given values as the new nonzero values. sparse matrix and use the given values as the new non-zero values.
Parameters Parameters
---------- ----------
mat : SparseMatrix mat : SparseMatrix
An existing sparse matrix with nnz nonzero values An existing sparse matrix with non-zero values
val : tensor val : tensor
The new nonzero values, a tensor of shape (nnz) or (nnz, D) The new values of the non-zero elements, a tensor of shape (nnz) or (nnz, D)
Returns Returns
------- -------
...@@ -691,7 +685,8 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix: ...@@ -691,7 +685,8 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
def _sparse_matrix_str(spmat: SparseMatrix) -> str: def _sparse_matrix_str(spmat: SparseMatrix) -> str:
"""Internal function for converting a sparse matrix to string """Internal function for converting a sparse matrix to string
representation.""" representation.
"""
indices_str = str(torch.stack(spmat.coo())) indices_str = str(torch.stack(spmat.coo()))
values_str = str(spmat.val) values_str = str(spmat.val)
meta_str = f"size={spmat.shape}, nnz={spmat.nnz}" meta_str = f"size={spmat.shape}, nnz={spmat.nnz}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment