"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bf5ca036fa7fbd6b46dc67df76d782eb90a860ca"
Unverified Commit 8c3e7830 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Sparse] Elementwise multiplication of a SparseMatrix object by a scalar (#5076)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 0f1bcd99
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from .diag_matrix import DiagMatrix from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix, val_like from .sparse_matrix import SparseMatrix, val_like
__all__ = ["sp_add", "sp_power"] __all__ = ["sp_add", "sp_mul", "sp_power"]
def spsp_add(A, B): def spsp_add(A, B):
...@@ -54,6 +54,53 @@ def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix: ...@@ -54,6 +54,53 @@ def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix:
) )
def sp_mul(
A: Union[SparseMatrix, float, int], B: Union[SparseMatrix, float, int]
) -> SparseMatrix:
"""Elementwise multiplication
Parameters
----------
A : SparseMatrix or float or int
First operand
B : SparseMatrix or float or int
Second operand
Returns
-------
SparseMatrix
Result of A * B
Examples
--------
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([1, 2, 3])
>>> A = create_from_coo(row, col, val, shape=(3, 4))
>>> A * 2
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3)
>>> 2 * A
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3)
"""
if isinstance(A, SparseMatrix) and isinstance(B, (float, int)):
return val_like(A, A.val * B)
elif isinstance(A, (float, int)) and isinstance(B, SparseMatrix):
return val_like(B, A * B.val)
raise RuntimeError(
"Elementwise multiplication between "
f"{type(A)} and {type(B)} is not supported."
)
def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with """Take the power of each nonzero element and return a sparse matrix with
the result. the result.
...@@ -107,5 +154,7 @@ def sp_rpower(A: SparseMatrix, scalar: Union[float, int]): ...@@ -107,5 +154,7 @@ def sp_rpower(A: SparseMatrix, scalar: Union[float, int]):
SparseMatrix.__add__ = sp_add SparseMatrix.__add__ = sp_add
SparseMatrix.__radd__ = sp_add SparseMatrix.__radd__ = sp_add
SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = sp_mul
SparseMatrix.__pow__ = sp_power SparseMatrix.__pow__ = sp_power
SparseMatrix.__rpow__ = sp_rpower SparseMatrix.__rpow__ = sp_rpower
import operator
import sys import sys
import backend as F import backend as F
...@@ -21,6 +20,25 @@ def all_close_sparse(A, row, col, val, shape): ...@@ -21,6 +20,25 @@ def all_close_sparse(A, row, col, val, shape):
assert A.shape == shape assert A.shape == shape
@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_mul_scalar(v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx)
val = torch.randn(len(row)).to(ctx)
A1 = create_from_coo(row, col, val, shape=(3, 4))
# A * v
A2 = A1 * v_scalar
assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape
# v * A
A2 = v_scalar * A1
assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)]) @pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape): def test_pow(val_shape):
# A ** v # A ** v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment