Unverified Commit 9c647d80 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] Add elementwise mul and div method, and polish the python doc. (#5142)



* add mul and div

* polish

* equivalent

* indent

* double
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 491c19fe
......@@ -5,18 +5,19 @@ from typing import Union
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix
__all__ = ["add", "power"]
__all__ = ["add", "sub", "mul", "div", "power"]
def add(
A: Union[DiagMatrix, SparseMatrix], B: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
"""Elementwise additions for `DiagMatrix` and `SparseMatrix`.
"""Elementwise additions for ``DiagMatrix`` and ``SparseMatrix``. This is
equivalent to ``A + B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \ B | DiagMatrix | SparseMatrix | scalar |
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
......@@ -45,7 +46,7 @@ def add(
>>> val = torch.tensor([10, 20, 30])
>>> A = from_coo(row, col, val)
>>> B = diag(torch.arange(1, 4))
>>> A + B
>>> add(A, B)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
[0, 1, 0, 1, 2]]),
values=tensor([ 1, 20, 10, 2, 33]),
......@@ -55,12 +56,13 @@ def add(
def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]:
"""Elementwise subtraction for `DiagMatrix` and `SparseMatrix`.
"""Elementwise subtraction for ``DiagMatrix`` and ``SparseMatrix``. This is
equivalent to ``A - B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \ B | DiagMatrix | SparseMatrix | scalar |
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
......@@ -85,22 +87,125 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]:
--------
>>> A = diag(torch.arange(1, 4))
>>> B = diag(torch.arange(10, 13))
>>> A - B
>>> sub(A, B)
DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3))
"""
return A - B
def mul(
A: Union[SparseMatrix, DiagMatrix, float, int],
B: Union[SparseMatrix, DiagMatrix, float, int],
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``. This
is equivalent to ``A * B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| scalar | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
Parameters
----------
A : SparseMatrix or DiagMatrix or float or int
Sparse matrix or diagonal matrix or scalar value
B : SparseMatrix or DiagMatrix or float or int
Sparse matrix or diagonal matrix or scalar value
Returns
-------
SparseMatrix or DiagMatrix
Either sparse matrix or diagonal matrix
Examples
--------
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([10, 20, 30])
>>> A = from_coo(row, col, val)
>>> mul(A, 2)
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([20, 40, 60]),
shape=(3, 4), nnz=3)
>>> D = diag(torch.arange(1, 4))
>>> mul(D, 2)
DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3))
>>> D = diag(torch.arange(1, 4))
>>> mul(D, D)
DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3))
"""
return A * B
def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, float, int]
) -> Union[DiagMatrix]:
"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``. This is
equivalent to ``A / B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters
----------
A : DiagMatrix
Diagonal matrix
B : DiagMatrix or float or int
Diagonal matrix or scalar value
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> A = diag(torch.arange(1, 4))
>>> B = diag(torch.arange(10, 13))
>>> div(A, B)
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3))
>>> A = diag(torch.arange(1, 4))
>>> div(A, 2)
DiagMatrix(val=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 3))
"""
return A / B
def power(
A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int]
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise exponentiation for `DiagMatrix` and `SparseMatrix`.
"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``. This
is equivalent to ``A ** scalar``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \ B | DiagMatrix | SparseMatrix | scalar |
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
......@@ -123,7 +228,6 @@ def power(
Examples
--------
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([10, 20, 30])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment