"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "69467ea59003e950152e2abb9e447807c45cad79"
Unverified commit 354a2110, authored by Mufei Li, committed by GitHub
Browse files

[Sparse] Elementwise ops for DiagMatrix and Elementwise Power for SparseMatrix (#5024)



* Update

* Update

* lint

* CI

* Update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 32dc1af6
......@@ -6,7 +6,7 @@ from .elementwise_op_diag import diag_add
from .elementwise_op_sp import sp_add
from .sparse_matrix import SparseMatrix
__all__ = ["add"]
__all__ = ["add", "power"]
def add(
......@@ -16,3 +16,42 @@ def add(
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return diag_add(A, B)
return sp_add(A, B)
def power(
    A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int]
) -> Union[SparseMatrix, DiagMatrix]:
    """Raise each nonzero element to the given exponent and return a matrix
    with the result.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Sparse matrix or diagonal matrix
    scalar : float or int
        Exponent

    Returns
    -------
    SparseMatrix or DiagMatrix
        Sparse matrix or diagonal matrix, same type as A

    Examples
    --------
    >>> row = torch.tensor([1, 0, 2])
    >>> col = torch.tensor([0, 3, 2])
    >>> val = torch.tensor([10, 20, 30])
    >>> A = create_from_coo(row, col, val)
    >>> power(A, 2)
    SparseMatrix(indices=tensor([[1, 0, 2],
            [0, 3, 2]]),
    values=tensor([100, 400, 900]),
    shape=(3, 4), nnz=3)

    >>> D = diag(torch.arange(1, 4))
    >>> power(D, 2)
    DiagMatrix(val=tensor([1, 4, 9]),
    shape=(3, 3))
    """
    # Both matrix types implement __pow__, so the functional form simply
    # dispatches to the operator implementation of whichever type A is.
    return A.__pow__(scalar)
......@@ -7,7 +7,7 @@ __all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise addition.
"""Elementwise addition
Parameters
----------
......@@ -42,7 +42,7 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise subtraction.
"""Elementwise subtraction
Parameters
----------
......@@ -79,7 +79,7 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
def diag_mul(
D1: Union[DiagMatrix, float, int], D2: Union[DiagMatrix, float, int]
) -> DiagMatrix:
"""Elementwise multiplication.
"""Elementwise multiplication
Parameters
----------
......@@ -120,15 +120,17 @@ def diag_mul(
)
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
"""Elementwise division.
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
"""Elementwise division of a diagonal matrix by a diagonal matrix or a
scalar
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value
D2 : DiagMatrix or float or int
Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix,
division is only applied to the diagonal elements.
Returns
-------
......@@ -146,41 +148,50 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
f"The shape of diagonal matrix D1 {D1.shape} and D2 {D2.shape} "
"must match."
)
return diag(D1.val / D2.val, D1.shape)
return diag(D1.val / D2, D1.shape)
elif isinstance(D2, (float, int)):
assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape)
raise RuntimeError(
f"Elementwise division between a diagonal matrix and {type(D2)} is "
"not supported."
)
def diag_rdiv(D1: float, D2: DiagMatrix):
"""Elementwise division.
def diag_rdiv(D1: DiagMatrix, D2: Union[float, int]):
    """Guard that rejects elementwise division of a scalar by a diagonal
    matrix.

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : float or int
        Scalar value
    """
    # This operation is deliberately unsupported: fail loudly with a
    # descriptive error rather than returning NotImplemented.
    message = (
        f"Elementwise division of {type(D2)} by a diagonal matrix is not "
        "supported."
    )
    raise RuntimeError(message)
def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
"""Elementwise power operation.
# pylint: disable=invalid-name
def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
"""Take the power of each nonzero element and return a diagonal matrix with
the result.
Parameters
----------
D1 : DiagMatrix
D : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value.
scalar : float or int
Exponent
Returns
-------
......@@ -189,34 +200,32 @@ def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
Examples
--------
>>> D1 = diag(torch.arange(1, 4))
>>> pow(D1, 2)
>>> D = diag(torch.arange(1, 4))
>>> D ** 2
DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(pow(D1.val, D2.val))
return DiagMatrix(pow(D1.val, D2))
if isinstance(scalar, (float, int)):
return diag(D.val**scalar, D.shape)
raise RuntimeError(
f"Raising a diagonal matrix to exponent {type(scalar)} is not allowed."
)
def diag_rpower(D1: float, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise power operator.
def diag_rpower(D: DiagMatrix, scalar: Union[float, int]):
    """Function for preventing raising a scalar to a diagonal matrix exponent.

    Parameters
    ----------
    D : DiagMatrix
        Diagonal matrix
    scalar : float or int
        Scalar

    Raises
    ------
    RuntimeError
        Always, since ``scalar ** D`` is not supported.
    """
    # Message fixed to say "exponent" (matching the docstring): in
    # `scalar ** D` the diagonal matrix is the exponent, not a "component".
    raise RuntimeError(
        f"Raising {type(scalar)} to a diagonal matrix exponent is not "
        "allowed."
    )
......
"""DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch
from .sparse_matrix import SparseMatrix
from .sparse_matrix import SparseMatrix, val_like
__all__ = ["sp_add"]
__all__ = ["sp_add", "sp_power"]
def spsp_add(A, B):
......@@ -49,5 +51,58 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
)
def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
    """Raise each nonzero element to the given exponent and return a sparse
    matrix with the result.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    scalar : float or int
        Exponent

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    >>> row = torch.tensor([1, 0, 2])
    >>> col = torch.tensor([0, 3, 2])
    >>> val = torch.tensor([10, 20, 30])
    >>> A = create_from_coo(row, col, val)
    >>> A ** 2
    SparseMatrix(indices=tensor([[1, 0, 2],
            [0, 3, 2]]),
    values=tensor([100, 400, 900]),
    shape=(3, 4), nnz=3)
    """
    # Guard clause: only numeric exponents are meaningful here.
    if not isinstance(scalar, (int, float)):
        raise RuntimeError(
            f"Raising a sparse matrix to exponent {type(scalar)} is not "
            "allowed."
        )
    # Only the stored values change; the sparsity pattern of A is reused.
    exponentiated = A.val**scalar
    return val_like(A, exponentiated)
def sp_rpower(A: SparseMatrix, scalar: Union[float, int]):
    """Function for preventing raising a scalar to a sparse matrix exponent.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    scalar : float or int
        Scalar

    Raises
    ------
    RuntimeError
        Always, since ``scalar ** A`` is not supported.
    """
    # Message fixed to say "exponent" (matching the docstring): in
    # `scalar ** A` the sparse matrix is the exponent, not a "component".
    raise RuntimeError(
        f"Raising {type(scalar)} to a sparse matrix exponent is not allowed."
    )
# Wire the elementwise helpers into Python's operator protocol so that
# `A + B`, `B + A`, `A ** s`, and `s ** A` dispatch to them. Addition is
# commutative, so __radd__ reuses sp_add; sp_rpower only raises, blocking
# `scalar ** SparseMatrix`.
SparseMatrix.__add__ = sp_add
SparseMatrix.__radd__ = sp_add
SparseMatrix.__pow__ = sp_power
SparseMatrix.__rpow__ = sp_rpower
......@@ -2,10 +2,10 @@ import operator
import sys
import backend as F
import numpy as np
import pytest
import torch
from dgl.mock_sparse2 import diag
from dgl.mock_sparse2 import diag, power
# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
......@@ -54,6 +54,11 @@ def test_diag_op_scalar(v_scalar):
# D ^ v
D1 = diag(torch.arange(1, 4).to(ctx))
D2 = D1 ** v_scalar
assert torch.allclose(D1.val ** v_scalar, D2.val, rtol=1e-4, atol=1e-4)
D2 = D1**v_scalar
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# pow(D, v)
D2 = power(D1, v_scalar)
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
import operator
import sys
import numpy as np
import backend as F
import pytest
import torch
import sys
import dgl
from dgl.mock_sparse2 import create_from_coo, diag
from dgl.mock_sparse2 import create_from_coo, power
# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)
def all_close_sparse(A, row, col, val, shape):
rowA, colA = A.coo()
valA = A.val
......@@ -22,11 +23,12 @@ def all_close_sparse(A, row, col, val, shape):
@pytest.mark.parametrize("op", [operator.add])
def test_sparse_op_sparse(op):
rowA = torch.tensor([1, 0, 2, 7, 1])
colA = torch.tensor([0, 49, 2, 1, 7])
valA = torch.rand(len(rowA))
ctx = F.ctx()
rowA = torch.tensor([1, 0, 2, 7, 1]).to(ctx)
colA = torch.tensor([0, 49, 2, 1, 7]).to(ctx)
valA = torch.rand(len(rowA)).to(ctx)
A = create_from_coo(rowA, colA, valA, shape=(10, 50))
w = torch.rand(len(rowA))
w = torch.rand(len(rowA)).to(ctx)
A1 = create_from_coo(rowA, colA, w, shape=(10, 50))
def _test():
......@@ -35,20 +37,26 @@ def test_sparse_op_sparse(op):
_test()
# Skipped: elementwise ops between matrices with different sparsity are not
# implemented, so there is nothing to assert yet.
@pytest.mark.skip(
    reason="No way to test it because we does not element-wise op \
between matrices with different sparsity"
)
@pytest.mark.parametrize("op", [operator.add])
def test_sparse_op_diag(op):
    # Builds a 10x50 sparse matrix and a same-shape diagonal matrix, then
    # checks op(A, D) against D's sparse form.
    # NOTE(review): this test references `diag`, but the updated import line
    # in this change drops `diag` from the module imports — harmless while
    # skipped, but it will NameError if the skip is removed; confirm.
    rowA = torch.tensor([1, 0, 2, 7, 1])
    colA = torch.tensor([0, 49, 2, 1, 7])
    valA = torch.rand(len(rowA))
    A = create_from_coo(rowA, colA, valA, shape=(10, 50))
    D = diag(torch.arange(2, 12), shape=A.shape)
    D_sp = D.as_sparse()

    def _test():
        all_close_sparse(op(A, D), *D_sp.coo(), [10, 50])

    _test()
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape):
    """Both `A ** v` and `power(A, v)` must exponentiate the nonzero values
    while preserving the sparsity pattern and the matrix shape."""
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(val_shape).to(ctx)
    A = create_from_coo(row, col, val, shape=(3, 4))
    exponent = 2

    def check(result):
        # Values are exponentiated; indices and shape are untouched.
        assert torch.allclose(result.val, val**exponent)
        assert result.shape == A.shape
        res_row, res_col = result.coo()
        assert torch.allclose(res_row, row)
        assert torch.allclose(res_col, col)

    check(A**exponent)  # operator form
    check(power(A, exponent))  # functional form
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment