Unverified Commit 82eb3d71 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] polish strings of matmul.py and sddmm.py. (#5177)



* polish mm docstrings

* fix typo

* small fix.

* fix
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 84aae086
......@@ -17,14 +17,14 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
Parameters
----------
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
Sparse matrix of shape ``(L, M)`` with scalar values
X : torch.Tensor
Dense tensor of shape (M, F) or (M)
Dense matrix of shape ``(M, N)`` or ``(M)``
Returns
-------
torch.Tensor
The multiplication result of shape (N, F) or (N)
The dense matrix of shape ``(L, N)`` or ``(L)``
Examples
--------
......@@ -42,10 +42,10 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
"""
assert isinstance(
A, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance(
X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}"
), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
# The input is a DiagMatrix. Cast it to SparseMatrix
if not isinstance(A, SparseMatrix):
......@@ -60,14 +60,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
Parameters
----------
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M, B) with values of shape (nnz)
Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
X : torch.Tensor
Dense tensor of shape (M, F, B)
Dense matrix of shape ``(M, N, K)``
Returns
-------
torch.Tensor
The multiplication result of shape (N, F, B)
Dense matrix of shape ``(L, N, K)``
Examples
--------
......@@ -85,27 +85,27 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
"""
assert isinstance(
A, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance(
X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}"
), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
return spmm(A, X)
def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix
"""Internal function for multiplying a diagonal matrix by a diagonal matrix.
Parameters
----------
A : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
Diagonal matrix of shape ``(L, M)``
B : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Diagonal matrix of shape ``(M, N)``
Returns
-------
DiagMatrix
The result of multiplication.
Diagonal matrix of shape ``(L, N)``
"""
M, N = A.shape
N, P = B.shape
......@@ -124,14 +124,14 @@ def _sparse_diag_mm(A, D):
Parameters
----------
A : SparseMatrix
Matrix of shape (N, M), with values of shape (nnz1)
Sparse matrix of shape ``(L, M)``
D : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Diagonal matrix of shape ``(M, N)``
Returns
-------
SparseMatrix
SparseMatrix with shape (N, P)
Sparse matrix of shape ``(L, N)``
"""
assert (
A.shape[1] == D.shape[0]
......@@ -142,24 +142,24 @@ def _sparse_diag_mm(A, D):
assert (
D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix) \
but got {D.shape}"
but got {D.shape}."
return val_like(A, D.val[A.col] * A.val)
def _diag_sparse_mm(D, A):
"""Internal function for multiplying a diag matrix by a sparse matrix.
"""Internal function for multiplying a diagonal matrix by a sparse matrix.
Parameters
----------
D : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Diagonal matrix of shape ``(L, M)``
A : SparseMatrix
Sparse matrix of shape ``(M, N)``
Returns
-------
SparseMatrix
SparseMatrix with shape (N, P)
Sparse matrix of shape ``(L, N)``
"""
assert (
D.shape[1] == A.shape[0]
......@@ -170,7 +170,7 @@ def _diag_sparse_mm(D, A):
assert (
D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix) \
but got {D.shape}"
but got {D.shape}."
return val_like(A, D.val[A.row] * A.val)
......@@ -184,15 +184,15 @@ def spspmm(
Parameters
----------
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
Sparse matrix of shape ``(L, M)``
B : SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz)
Sparse matrix of shape ``(M, N)``
Returns
-------
SparseMatrix or DiagMatrix
The result of multiplication. It is a DiagMatrix object if both matrices
are DiagMatrix objects. It is a SparseMatrix object otherwise.
Matrix of shape ``(L, N)``. It is a DiagMatrix object if both matrices
are DiagMatrix objects, otherwise a SparseMatrix object.
Examples
--------
......@@ -215,10 +215,10 @@ def spspmm(
"""
assert isinstance(
A, (SparseMatrix, DiagMatrix)
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}."
assert isinstance(
B, (SparseMatrix, DiagMatrix)
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}"
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}."
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B)
......
......@@ -25,16 +25,16 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
Parameters
----------
A : SparseMatrix
Sparse matrix of shape ``(M, N)``.
X1 : Tensor
Dense matrix of shape ``(M, K)`` or ``(M,)``
X2 : Tensor
Dense matrix of shape ``(K, N)`` or ``(N,)``
Sparse matrix of shape ``(L, N)``
X1 : torch.Tensor
Dense matrix of shape ``(L, M)`` or ``(L,)``
X2 : torch.Tensor
Dense matrix of shape ``(M, N)`` or ``(N,)``
Returns
-------
SparseMatrix
Sparse matrix of shape ``(M, N)``.
Sparse matrix of shape ``(L, N)``
Examples
--------
......@@ -58,32 +58,33 @@ def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
def bsddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.
``sddmm`` multiplies two dense matrices :attr:`X1` and :attr:`X2` at the
nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A` is not
considered during the computation.
``sddmm`` matrix-multiplies two dense matrices :attr:`X1` and :attr:`X2`,
then elementwise-multiplies the result with sparse matrix :attr:`A` at the
nonzero locations.
Mathematically ``sddmm`` is formulated as:
.. math::
out = (X1 @ X2) * A
The batch dimension is the last dimension for input matrices. In particular,
if the sparse matrix has scalar non-zero values, it will be broadcasted
for bsddmm.
The batch dimension is the last dimension for input dense matrices. In
particular, if the sparse matrix has scalar non-zero values, it will be
broadcasted for bsddmm.
Parameters
----------
A : SparseMatrix
Sparse matrix of shape ``(M, N)`` or ``(M, N, B)``.
Sparse matrix of shape ``(L, N)`` with scalar values or vector values of
length ``K``
X1 : Tensor
Dense matrix of shape ``(M, K, B)``
Dense matrix of shape ``(L, M, K)``
X2 : Tensor
Dense matrix of shape ``(K, N, B)``
Dense matrix of shape ``(M, N, K)``
Returns
-------
SparseMatrix
Sparse matrix of shape ``(M, N, B)``.
Sparse matrix of shape ``(L, N)`` with vector values of length ``K``
Examples
--------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment