Unverified Commit 82eb3d71 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] polish strings of matmul.py and sddmm.py. (#5177)



* polish mm docstrings

* fix typo

* small fix.

* fix
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 84aae086
...@@ -17,14 +17,14 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -17,14 +17,14 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz) Sparse matrix of shape ``(L, M)`` with scalar values
X : torch.Tensor X : torch.Tensor
Dense tensor of shape (M, F) or (M) Dense matrix of shape ``(M, N)`` or ``(M)``
Returns Returns
------- -------
torch.Tensor torch.Tensor
The multiplication result of shape (N, F) or (N) The dense matrix of shape ``(L, N)`` or ``(L)``
Examples Examples
-------- --------
...@@ -42,10 +42,10 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -42,10 +42,10 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}" ), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
X, torch.Tensor X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}" ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
# The input is a DiagMatrix. Cast it to SparseMatrix # The input is a DiagMatrix. Cast it to SparseMatrix
if not isinstance(A, SparseMatrix): if not isinstance(A, SparseMatrix):
...@@ -60,14 +60,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -60,14 +60,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M, B) with values of shape (nnz) Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
X : torch.Tensor X : torch.Tensor
Dense tensor of shape (M, F, B) Dense matrix of shape ``(M, N, K)``
Returns Returns
------- -------
torch.Tensor torch.Tensor
The multiplication result of shape (N, F, B) Dense matrix of shape ``(L, N, K)``
Examples Examples
-------- --------
...@@ -85,27 +85,27 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -85,27 +85,27 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}" ), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
X, torch.Tensor X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}" ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
return spmm(A, X) return spmm(A, X)
def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix: def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix """Internal function for multiplying a diagonal matrix by a diagonal matrix.
Parameters Parameters
---------- ----------
A : DiagMatrix A : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1) Diagonal matrix of shape ``(L, M)``
B : DiagMatrix B : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2) Diagonal matrix of shape ``(M, N)``
Returns Returns
------- -------
DiagMatrix DiagMatrix
The result of multiplication. Diagonal matrix of shape ``(L, N)``
""" """
M, N = A.shape M, N = A.shape
N, P = B.shape N, P = B.shape
...@@ -124,14 +124,14 @@ def _sparse_diag_mm(A, D): ...@@ -124,14 +124,14 @@ def _sparse_diag_mm(A, D):
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
Matrix of shape (N, M), with values of shape (nnz1) Sparse matrix of shape ``(L, M)``
D : DiagMatrix D : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2) Diagonal matrix of shape ``(M, N)``
Returns Returns
------- -------
SparseMatrix SparseMatrix
SparseMatrix with shape (N, P) Sparse matrix of shape ``(L, N)``
""" """
assert ( assert (
A.shape[1] == D.shape[0] A.shape[1] == D.shape[0]
...@@ -142,24 +142,24 @@ def _sparse_diag_mm(A, D): ...@@ -142,24 +142,24 @@ def _sparse_diag_mm(A, D):
assert ( assert (
D.shape[0] == D.shape[1] D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix) \ ), f"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix) \
but got {D.shape}" but got {D.shape}."
return val_like(A, D.val[A.col] * A.val) return val_like(A, D.val[A.col] * A.val)
def _diag_sparse_mm(D, A): def _diag_sparse_mm(D, A):
"""Internal function for multiplying a diag matrix by a sparse matrix. """Internal function for multiplying a diagonal matrix by a sparse matrix.
Parameters Parameters
---------- ----------
D : DiagMatrix D : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1) Diagonal matrix of shape ``(L, M)``
A : DiagMatrix A : SparseMatrix
Matrix of shape (M, P), with values of shape (nnz2) Sparse matrix of shape ``(M, N)``
Returns Returns
------- -------
SparseMatrix SparseMatrix
SparseMatrix with shape (N, P) Sparse matrix of shape ``(L, N)``
""" """
assert ( assert (
D.shape[1] == A.shape[0] D.shape[1] == A.shape[0]
...@@ -170,7 +170,7 @@ def _diag_sparse_mm(D, A): ...@@ -170,7 +170,7 @@ def _diag_sparse_mm(D, A):
assert ( assert (
D.shape[0] == D.shape[1] D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix) \ ), f"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix) \
but got {D.shape}" but got {D.shape}."
return val_like(A, D.val[A.row] * A.val) return val_like(A, D.val[A.row] * A.val)
...@@ -184,15 +184,15 @@ def spspmm( ...@@ -184,15 +184,15 @@ def spspmm(
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz) Sparse matrix of shape ``(L, M)``
B : SparseMatrix or DiagMatrix B : SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz) Sparse matrix of shape ``(M, N)``
Returns Returns
------- -------
SparseMatrix or DiagMatrix SparseMatrix or DiagMatrix
The result of multiplication. It is a DiagMatrix object if both matrices Matrix of shape ``(L, N)``. It is a DiagMatrix object if both matrices
are DiagMatrix objects. It is a SparseMatrix object otherwise. are DiagMatrix objects, otherwise a SparseMatrix object.
Examples Examples
-------- --------
...@@ -215,10 +215,10 @@ def spspmm( ...@@ -215,10 +215,10 @@ def spspmm(
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, (SparseMatrix, DiagMatrix)
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}" ), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
B, (SparseMatrix, DiagMatrix) B, (SparseMatrix, DiagMatrix)
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}" ), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}."
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix): if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B) return _diag_diag_mm(A, B)
......
...@@ -25,16 +25,16 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix: ...@@ -25,16 +25,16 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
Sparse matrix of shape ``(M, N)``. Sparse matrix of shape ``(L, N)``
X1 : Tensor X1 : torch.Tensor
Dense matrix of shape ``(M, K)`` or ``(M,)`` Dense matrix of shape ``(L, M)`` or ``(L,)``
X2 : Tensor X2 : torch.Tensor
Dense matrix of shape ``(K, N)`` or ``(N,)`` Dense matrix of shape ``(M, N)`` or ``(N,)``
Returns Returns
------- -------
SparseMatrix SparseMatrix
Sparse matrix of shape ``(M, N)``. Sparse matrix of shape ``(L, N)``
Examples Examples
-------- --------
...@@ -58,32 +58,33 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix: ...@@ -58,32 +58,33 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
def bsddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix: def bsddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches. r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.
``sddmm`` multiplies two dense matrices :attr:`X1` and :attr:`X2` at the ``sddmm`` matrix-multiplies two dense matrices :attr:`X1` and :attr:`X2`,
nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A` is not then elementwise-multiplies the result with sparse matrix :attr:`A` at the
considered during the computation. nonzero locations.
Mathematically ``sddmm`` is formulated as: Mathematically ``sddmm`` is formulated as:
.. math:: .. math::
out = (X1 @ X2) * A out = (X1 @ X2) * A
The batch dimension is the last dimension for input matrices. In particular, The batch dimension is the last dimension for input dense matrices. In
if the sparse matrix has scalar non-zero values, it will be broadcasted particular, if the sparse matrix has scalar non-zero values, it will be
for bsddmm. broadcasted for bsddmm.
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
Sparse matrix of shape ``(M, N)`` or ``(M, N, B)``. Sparse matrix of shape ``(L, N)`` with scalar values or vector values of
length ``K``
X1 : Tensor X1 : Tensor
Dense matrix of shape ``(M, K, B)`` Dense matrix of shape ``(L, M, K)``
X2 : Tensor X2 : Tensor
Dense matrix of shape ``(K, N, B)`` Dense matrix of shape ``(M, N, K)``
Returns Returns
------- -------
SparseMatrix SparseMatrix
Sparse matrix of shape ``(M, N, B)``. Sparse matrix of shape ``(L, N)`` with vector values of length ``K``
Examples Examples
-------- --------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment