Unverified Commit 3d01340d authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

[Sparse] Add elementwise operators for diagonal matrix (#4644)



* Add elementwise operators for diagonal matrix

* minor

* resolve ambiguity of ops

* lint check

* lint check

* lint check
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
parent eb729c54
"""dgl sparse class.""" """dgl sparse class."""
from .diag_matrix import * from .diag_matrix import *
from .sp_matrix import * from .sp_matrix import *
from .elementwise_op_sp import * from .elementwise_op import *
from .sddmm import * from .sddmm import *
from .reduction import * # pylint: disable=W0622 from .reduction import * # pylint: disable=W0622
from .unary_diag import * from .unary_diag import *
......
"""DGL elementwise operator module."""
from typing import Union
from .diag_matrix import DiagMatrix
from .elementwise_op_diag import (
diag_add,
diag_sub,
diag_mul,
diag_div,
diag_power,
)
from .elementwise_op_sp import sp_add, sp_sub, sp_mul, sp_div, sp_power
from .sp_matrix import SparseMatrix
__all__ = ["add", "sub", "mul", "div", "power"]
def add(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise addition"""
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return diag_add(A, B)
return sp_add(A, B)
def sub(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise addition"""
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return diag_sub(A, B)
return sp_sub(A, B)
def mul(
A: Union[SparseMatrix, DiagMatrix, float],
B: Union[SparseMatrix, DiagMatrix, float],
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise multiplication"""
if isinstance(A, SparseMatrix) or isinstance(B, SparseMatrix):
return sp_mul(A, B)
return diag_mul(A, B)
def div(
A: Union[SparseMatrix, DiagMatrix],
B: Union[SparseMatrix, DiagMatrix, float],
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise division"""
if isinstance(A, SparseMatrix) or isinstance(B, SparseMatrix):
return sp_div(A, B)
return diag_div(A, B)
def power(
A: Union[SparseMatrix, DiagMatrix], B: float
) -> Union[SparseMatrix, DiagMatrix]:
"""Elementwise division"""
if isinstance(A, SparseMatrix) or isinstance(B, SparseMatrix):
return sp_power(A, B)
return diag_power(A, B)
"""DGL elementwise operators for diagonal matrix module."""
from typing import Union
from .diag_matrix import DiagMatrix
__all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise addition.
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix
Diagonal matrix
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 + D2
DiagMatrix(val=tensor([11, 13, 15]),
shape=(3, 3))
"""
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" " D2 {} must match.".format(
D1.shape, D2.shape
)
return DiagMatrix(D1.val + D2.val)
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise subtraction.
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix
Diagonal matrix
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 -D2
DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3))
"""
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(D1.val - D2.val)
def diag_mul(
D1: Union[DiagMatrix, float], D2: Union[DiagMatrix, float]
) -> DiagMatrix:
"""Elementwise multiplication.
Parameters
----------
D1 : DiagMatrix or scalar
Diagonal matrix or scalar value
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value
Returns
-------
DiagMatrix
diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
DiagMatrix(val=tensor([10, 22, 36]),
shape=(3, 3))
>>> D1 * 2.5
DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
shape=(3, 3))
>>> 2 * D1
DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(D1.val * D2.val)
return DiagMatrix(D1.val * D2)
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
"""Elementwise division.
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value
Returns
-------
DiagMatrix
diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 / D2
>>> D1/D2
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3))
>>> D1/2.5
DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(D1.val / D2.val)
return DiagMatrix(D1.val / D2)
def diag_rdiv(D1: float, D2: DiagMatrix):
"""Elementwise division.
Parameters
----------
D1 : scalar
scalar value
D2 : DiagMatrix
Diagonal matrix
"""
raise RuntimeError(
"Elementwise subtraction between {} and {} is not "
"supported.".format(type(D1), type(D2))
)
def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
"""Elementwise power operation.
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value.
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> pow(D1, 2)
DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(pow(D1.val, D2.val))
return DiagMatrix(pow(D1.val, D2))
def diag_rpower(D1: float, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise power operator.
Parameters
----------
D1 : scalar
scalar value
D2 : DiagMatrix
Diagonal matrix
"""
raise RuntimeError(
"Elementwise subtraction between {} and {} is not "
"supported.".format(type(D1), type(D2))
)
DiagMatrix.__add__ = diag_add
DiagMatrix.__radd__ = diag_add
DiagMatrix.__sub__ = diag_sub
DiagMatrix.__rsub__ = diag_sub
DiagMatrix.__mul__ = diag_mul
DiagMatrix.__rmul__ = diag_mul
DiagMatrix.__truediv__ = diag_div
DiagMatrix.__rtruediv__ = diag_rdiv
DiagMatrix.__pow__ = diag_power
DiagMatrix.__rpow__ = diag_rpower
"""dgl elementwise operators for sparse matrix module.""" """DGL elementwise operators for sparse matrix module."""
from typing import Union from typing import Union
import torch import torch
from .sp_matrix import SparseMatrix
from .diag_matrix import DiagMatrix from .diag_matrix import DiagMatrix
from .sp_matrix import SparseMatrix
__all__ = ["add", "sub", "mul", "div", "rdiv", "power", "rpower"] __all__ = ["sp_add", "sp_sub", "sp_mul", "sp_div", "sp_power"]
def add( def sp_add(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix] A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> SparseMatrix: ) -> SparseMatrix:
"""Elementwise addition. """Elementwise addition.
...@@ -81,7 +81,7 @@ def add( ...@@ -81,7 +81,7 @@ def add(
) )
def sub( def sp_sub(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix] A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> SparseMatrix: ) -> SparseMatrix:
"""Elementwise subtraction. """Elementwise subtraction.
...@@ -139,7 +139,7 @@ def sub( ...@@ -139,7 +139,7 @@ def sub(
) )
def mul( def sp_mul(
A: Union[SparseMatrix, DiagMatrix, float], A: Union[SparseMatrix, DiagMatrix, float],
B: Union[SparseMatrix, DiagMatrix, float], B: Union[SparseMatrix, DiagMatrix, float],
) -> SparseMatrix: ) -> SparseMatrix:
...@@ -205,7 +205,7 @@ def mul( ...@@ -205,7 +205,7 @@ def mul(
return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape) return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape)
def div( def sp_div(
A: Union[SparseMatrix, DiagMatrix], A: Union[SparseMatrix, DiagMatrix],
B: Union[SparseMatrix, DiagMatrix, float], B: Union[SparseMatrix, DiagMatrix, float],
) -> SparseMatrix: ) -> SparseMatrix:
...@@ -260,7 +260,7 @@ def div( ...@@ -260,7 +260,7 @@ def div(
return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape) return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape)
def rdiv(A: float, B: Union[SparseMatrix, DiagMatrix]): def sp_rdiv(A: float, B: Union[SparseMatrix, DiagMatrix]):
"""Elementwise division. """Elementwise division.
Parameters Parameters
...@@ -276,7 +276,7 @@ def rdiv(A: float, B: Union[SparseMatrix, DiagMatrix]): ...@@ -276,7 +276,7 @@ def rdiv(A: float, B: Union[SparseMatrix, DiagMatrix]):
) )
def power(A: SparseMatrix, B: float) -> SparseMatrix: def sp_power(A: SparseMatrix, B: float) -> SparseMatrix:
"""Elementwise power operation. """Elementwise power operation.
Parameters Parameters
...@@ -310,7 +310,7 @@ def power(A: SparseMatrix, B: float) -> SparseMatrix: ...@@ -310,7 +310,7 @@ def power(A: SparseMatrix, B: float) -> SparseMatrix:
return SparseMatrix(A.row, A.col, torch.pow(A.val, B), A.shape) return SparseMatrix(A.row, A.col, torch.pow(A.val, B), A.shape)
def rpower(A: float, B: SparseMatrix) -> SparseMatrix: def sp_rpower(A: float, B: SparseMatrix) -> SparseMatrix:
"""Elementwise power operation. """Elementwise power operation.
Parameters Parameters
...@@ -326,13 +326,13 @@ def rpower(A: float, B: SparseMatrix) -> SparseMatrix: ...@@ -326,13 +326,13 @@ def rpower(A: float, B: SparseMatrix) -> SparseMatrix:
) )
SparseMatrix.__add__ = add SparseMatrix.__add__ = sp_add
SparseMatrix.__radd__ = add SparseMatrix.__radd__ = sp_add
SparseMatrix.__sub__ = sub SparseMatrix.__sub__ = sp_sub
SparseMatrix.__rsub__ = sub SparseMatrix.__rsub__ = sp_sub
SparseMatrix.__mul__ = mul SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = mul SparseMatrix.__rmul__ = sp_mul
SparseMatrix.__truediv__ = div SparseMatrix.__truediv__ = sp_div
SparseMatrix.__rtruediv__ = rdiv SparseMatrix.__rtruediv__ = sp_rdiv
SparseMatrix.__pow__ = power SparseMatrix.__pow__ = sp_power
SparseMatrix.__rpow__ = rpower SparseMatrix.__rpow__ = sp_rpower
import operator
import numpy as np
import pytest
import torch
from dgl.mock_sparse import diag
parametrize_idtype = pytest.mark.parametrize(
"idtype", [torch.int32, torch.int64]
)
parametrize_dtype = pytest.mark.parametrize(
"dtype", [torch.float32, torch.float64]
)
def all_close_sparse(A, B):
assert torch.allclose(A.indices(), B.indices())
assert torch.allclose(A.values(), B.values())
assert A.shape == B.shape
@parametrize_idtype
@parametrize_dtype
@pytest.mark.parametrize(
"op", [operator.add, operator.sub, operator.mul, operator.truediv]
)
def test_diag_op_diag(idtype, dtype, op):
D1 = diag(torch.arange(1, 4))
D2 = diag(torch.arange(10, 13))
assert np.allclose(op(D1, D2).val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
@parametrize_idtype
@parametrize_dtype
@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_diag_op_scalar(idtype, dtype, v_scalar):
D1 = diag(torch.arange(1, 50))
assert np.allclose(
D1.val * v_scalar, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
v_scalar * D1.val, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
D1.val / v_scalar, (D1 / v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
pow(D1.val, v_scalar), pow(D1, v_scalar).val, rtol=1e-4, atol=1e-4
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment