"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "84db2ac4572dd23b67d93d08660426e44f97ba75"
Unverified Commit acc567aa authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sparse] Support sparse matrix dividing scalar (#5173)



* use NotImplemented

* format

* extend to pytorch scalar

* sparse div scalar

* oops

* Apply suggestions from code review
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent b5c5c860
...@@ -152,8 +152,8 @@ def mul( ...@@ -152,8 +152,8 @@ def mul(
def div( def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar] A: Union[SparseMatrix, DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[DiagMatrix]: ) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``. to ``A / B``.
...@@ -164,15 +164,15 @@ def div( ...@@ -164,15 +164,15 @@ def div(
+--------------+------------+--------------+--------+ +--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ | | DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+ +--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | 🚫 | | SparseMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+ +--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 | | scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+ +--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : DiagMatrix A : SparseMatrix or DiagMatrix
Diagonal matrix Sparse or diagonal matrix
B : DiagMatrix or Scalar B : DiagMatrix or Scalar
Diagonal matrix or scalar value Diagonal matrix or scalar value
...@@ -193,6 +193,16 @@ def div( ...@@ -193,6 +193,16 @@ def div(
>>> div(A, 2) >>> div(A, 2)
DiagMatrix(val=tensor([0.5000, 1.0000, 1.5000]), DiagMatrix(val=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 3)) shape=(3, 3))
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([1, 2, 3])
>>> A = from_coo(row, col, val, shape=(3, 4))
>>> A / 2
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 4), nnz=3)
""" """
return A / B return A / B
......
...@@ -89,6 +89,40 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix: ...@@ -89,6 +89,40 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
return NotImplemented return NotImplemented
def sp_div(A: SparseMatrix, B: Scalar) -> SparseMatrix:
    """Divide each nonzero value of a sparse matrix by a scalar.

    The sparsity pattern of ``A`` is preserved; only the stored values
    are divided.

    Parameters
    ----------
    A : SparseMatrix
        First operand
    B : Scalar
        Second operand

    Returns
    -------
    SparseMatrix
        Result of A / B

    Examples
    --------
    >>> row = torch.tensor([1, 0, 2])
    >>> col = torch.tensor([0, 3, 2])
    >>> val = torch.tensor([1, 2, 3])
    >>> A = from_coo(row, col, val, shape=(3, 4))
    >>> A / 2
    SparseMatrix(indices=tensor([[1, 0, 2],
    [0, 3, 2]]),
    values=tensor([0.5000, 1.0000, 1.5000]),
    shape=(3, 4), nnz=3)
    """
    if not is_scalar(B):
        # Returning NotImplemented lets Python fall back to
        # B.__rtruediv__(A), and raise TypeError if that also fails.
        return NotImplemented
    return val_like(A, A.val / B)
def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix: def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with """Take the power of each nonzero element and return a sparse matrix with
the result. the result.
...@@ -125,4 +159,5 @@ def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix: ...@@ -125,4 +159,5 @@ def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
SparseMatrix.__add__ = sp_add SparseMatrix.__add__ = sp_add
SparseMatrix.__mul__ = sp_mul SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = sp_mul SparseMatrix.__rmul__ = sp_mul
SparseMatrix.__truediv__ = sp_div
SparseMatrix.__pow__ = sp_power SparseMatrix.__pow__ = sp_power
...@@ -23,7 +23,7 @@ def all_close_sparse(A, row, col, val, shape): ...@@ -23,7 +23,7 @@ def all_close_sparse(A, row, col, val, shape):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)] "v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
) )
def test_mul_scalar(v_scalar): def test_muldiv_scalar(v_scalar):
ctx = F.ctx() ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx) row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx) col = torch.tensor([0, 3, 2]).to(ctx)
...@@ -40,6 +40,15 @@ def test_mul_scalar(v_scalar): ...@@ -40,6 +40,15 @@ def test_mul_scalar(v_scalar):
assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4) assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape assert A1.shape == A2.shape
# A / v
A2 = A1 / v_scalar
assert torch.allclose(A1.val / v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape
# v / A
with pytest.raises(TypeError):
v_scalar / A1
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)]) @pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape): def test_pow(val_shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment