Unverified Commit e01580d3 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Sparse] Misc fix for existing elementwise ops (#5003)



* Update

* Update

* Update

* Update

* update

* lint

* CI

* CI

* CI
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent aad3bd04
"""DGL elementwise operators for diagonal matrix module.""" """DGL elementwise operators for diagonal matrix module."""
from typing import Union from typing import Union
from .diag_matrix import DiagMatrix from .diag_matrix import diag, DiagMatrix
__all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"] __all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
...@@ -23,18 +23,22 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: ...@@ -23,18 +23,22 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
Examples Examples
-------- --------
>>> D1 = DiagMatrix(torch.arange(1, 4)) >>> D1 = diag(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13)) >>> D2 = diag(torch.arange(10, 13))
>>> D1 + D2 >>> D1 + D2
DiagMatrix(val=tensor([11, 13, 15]), DiagMatrix(val=tensor([11, 13, 15]),
shape=(3, 3)) shape=(3, 3))
""" """
assert ( if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
D1.shape == D2.shape assert D1.shape == D2.shape, (
), "The shape of diagonal matrix D1 {} and" " D2 {} must match.".format( "The shape of diagonal matrix D1 "
D1.shape, D2.shape f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val + D2.val, D1.shape)
raise RuntimeError(
"Elementwise addition between "
f"{type(D1)} and {type(D2)} is not supported."
) )
return DiagMatrix(D1.val + D2.val)
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
...@@ -54,58 +58,66 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: ...@@ -54,58 +58,66 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
Examples Examples
-------- --------
>>> D1 = DiagMatrix(torch.arange(1, 4)) >>> D1 = diag(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13)) >>> D2 = diag(torch.arange(10, 13))
>>> D1 -D2 >>> D1 - D2
DiagMatrix(val=tensor([-9, -9, -9]), DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3)) shape=(3, 3))
""" """
assert ( if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
D1.shape == D2.shape assert D1.shape == D2.shape, (
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format( "The shape of diagonal matrix D1 "
D1.shape, D2.shape f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val - D2.val, D1.shape)
raise RuntimeError(
"Elementwise subtraction between "
f"{type(D1)} and {type(D2)} is not supported."
) )
return DiagMatrix(D1.val - D2.val)
def diag_mul( def diag_mul(
D1: Union[DiagMatrix, float], D2: Union[DiagMatrix, float] D1: Union[DiagMatrix, float, int], D2: Union[DiagMatrix, float, int]
) -> DiagMatrix: ) -> DiagMatrix:
"""Elementwise multiplication. """Elementwise multiplication.
Parameters Parameters
---------- ----------
D1 : DiagMatrix or scalar D1 : DiagMatrix or float or int
Diagonal matrix or scalar value Diagonal matrix or scalar value
D2 : DiagMatrix or scalar D2 : DiagMatrix or float or int
Diagonal matrix or scalar value Diagonal matrix or scalar value
Returns Returns
------- -------
DiagMatrix DiagMatrix
diagonal matrix Diagonal matrix
Examples Examples
-------- --------
>>> D1 = DiagMatrix(torch.arange(1, 4)) >>> D = diag(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13)) >>> D * 2.5
DiagMatrix(val=tensor([10, 22, 36]),
shape=(3, 3))
>>> D1 * 2.5
DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]), DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
shape=(3, 3)) shape=(3, 3))
>>> 2 * D1 >>> 2 * D
DiagMatrix(val=tensor([2, 4, 6]), DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3)) shape=(3, 3))
""" """
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix): if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert ( assert D1.shape == D2.shape, (
D1.shape == D2.shape "The shape of diagonal matrix D1 "
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format( f"{D1.shape} and D2 {D2.shape} must match."
D1.shape, D2.shape )
return diag(D1.val * D2.val, D1.shape)
elif isinstance(D1, DiagMatrix) and isinstance(D2, (float, int)):
return diag(D1.val * D2, D1.shape)
elif isinstance(D1, (float, int)) and isinstance(D2, DiagMatrix):
return diag(D1 * D2.val, D2.shape)
raise RuntimeError(
"Elementwise multiplication between "
f"{type(D1)} and {type(D2)} is not supported."
) )
return DiagMatrix(D1.val * D2.val)
return DiagMatrix(D1.val * D2)
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix: def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
...@@ -125,13 +137,12 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix: ...@@ -125,13 +137,12 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
Examples Examples
-------- --------
>>> D1 = DiagMatrix(torch.arange(1, 4)) >>> D1 = diag(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13)) >>> D2 = diag(torch.arange(10, 13))
>>> D1 / D2 >>> D1 / D2
>>> D1/D2
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]), DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3)) shape=(3, 3))
>>> D1/2.5 >>> D1 / 2.5
DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]), DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
shape=(3, 3)) shape=(3, 3))
""" """
...@@ -141,8 +152,8 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix: ...@@ -141,8 +152,8 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format( ), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape D1.shape, D2.shape
) )
return DiagMatrix(D1.val / D2.val) return diag(D1.val / D2.val, D1.shape)
return DiagMatrix(D1.val / D2) return diag(D1.val / D2, D1.shape)
def diag_rdiv(D1: float, D2: DiagMatrix): def diag_rdiv(D1: float, D2: DiagMatrix):
...@@ -178,7 +189,7 @@ def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix: ...@@ -178,7 +189,7 @@ def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
Examples Examples
-------- --------
>>> D1 = DiagMatrix(torch.arange(1, 4)) >>> D1 = diag(torch.arange(1, 4))
>>> pow(D1, 2) >>> pow(D1, 2)
DiagMatrix(val=tensor([1, 4, 9]), DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3)) shape=(3, 3))
......
"""DGL elementwise operators for sparse matrix module.""" """DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch import torch
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix from .sparse_matrix import SparseMatrix
__all__ = ["sp_add"] __all__ = ["sp_add"]
def spsp_add(A, B): def spsp_add(A, B):
""" Invoke C++ sparse library for addition """ """Invoke C++ sparse library for addition"""
return SparseMatrix( return SparseMatrix(
torch.ops.dgl_sparse.spsp_add(A.c_sparse_matrix, B.c_sparse_matrix) torch.ops.dgl_sparse.spsp_add(A.c_sparse_matrix, B.c_sparse_matrix)
) )
def sp_add( def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> SparseMatrix:
"""Elementwise addition. """Elementwise addition.
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse matrix or diagonal matrix Sparse matrix
B : SparseMatrix or DiagMatrix B : SparseMatrix
Sparse matrix or diagonal matrix Sparse matrix
Returns Returns
------- -------
...@@ -34,46 +30,17 @@ def sp_add( ...@@ -34,46 +30,17 @@ def sp_add(
Examples Examples
-------- --------
Case 1: Add two sparse matrices of same sparsity structure
>>> rowA = torch.tensor([1, 0, 2]) >>> row = torch.tensor([1, 0, 2])
>>> colA = torch.tensor([0, 3, 2]) >>> col = torch.tensor([0, 3, 2])
>>> valA = torch.tensor([10, 20, 30]) >>> val = torch.tensor([10, 20, 30])
>>> A = SparseMatrix(rowA, colA, valA, shape=(3, 4)) >>> A = create_from_coo(row, col, val, shape=(3, 4))
>>> A + A >>> A + A
SparseMatrix(indices=tensor([[0, 1, 2], SparseMatrix(indices=tensor([[0, 1, 2],
[3, 0, 2]]), [3, 0, 2]]),
values=tensor([40, 20, 60]), values=tensor([40, 20, 60]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
>>> w = torch.arange(1, len(rowA)+1)
>>> A + A(w)
SparseMatrix(indices=tensor([[0, 1, 2],
[3, 0, 2]]),
values=tensor([21, 12, 33]),
shape=(3, 4), nnz=3)
Case 2: Add two sparse matrices of different sparsity structure
>>> rowB = torch.tensor([1, 2, 0, 2, 1])
>>> colB = torch.tensor([0, 2, 1, 3, 3])
>>> valB = torch.tensor([1, 2, 3, 4, 5])
>>> B = SparseMatrix(rowB, colB, valB, shape=(3 ,4))
>>> A + B
SparseMatrix(indices=tensor([[0, 0, 1, 1, 2, 2],
[1, 3, 0, 3, 2, 3]]),
values=tensor([ 3, 20, 11, 5, 32, 4]),
shape=(3, 4), nnz=6)
Case 3: Add sparse matrix and diagonal matrix
>>> D = diag(torch.arange(2, 5), shape=A.shape)
>>> A + D
SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
[0, 3, 0, 1, 2]]),
values=tensor([ 2, 20, 10, 3, 34]),
shape=(3, 4), nnz=5)
""" """
B = B.as_sparse() if isinstance(B, DiagMatrix) else B
if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix): if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
return spsp_add(A, B) return spsp_add(A, B)
raise RuntimeError( raise RuntimeError(
......
import operator import operator
import sys
import backend as F
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import sys
from dgl.mock_sparse2 import diag from dgl.mock_sparse2 import diag
# TODO(#4818): Skipping tests on win. # TODO(#4818): Skipping tests on win.
...@@ -21,23 +22,38 @@ def all_close_sparse(A, B): ...@@ -21,23 +22,38 @@ def all_close_sparse(A, B):
"op", [operator.add, operator.sub, operator.mul, operator.truediv] "op", [operator.add, operator.sub, operator.mul, operator.truediv]
) )
def test_diag_op_diag(op): def test_diag_op_diag(op):
D1 = diag(torch.arange(1, 4)) ctx = F.ctx()
D2 = diag(torch.arange(10, 13)) shape = (3, 4)
assert np.allclose(op(D1, D2).val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4) D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)
result = op(D1, D2)
assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
assert result.shape == D1.shape
@pytest.mark.parametrize("v_scalar", [2, 2.5]) @pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_diag_op_scalar(v_scalar): def test_diag_op_scalar(v_scalar):
D1 = diag(torch.arange(1, 50)) ctx = F.ctx()
assert np.allclose( shape = (3, 4)
D1.val * v_scalar, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4 D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
)
assert np.allclose( # D * v
v_scalar * D1.val, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4 D2 = D1 * v_scalar
) assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert np.allclose( assert D1.shape == D2.shape
D1.val / v_scalar, (D1 / v_scalar).val, rtol=1e-4, atol=1e-4
) # v * D
assert np.allclose( D2 = v_scalar * D1
pow(D1.val, v_scalar), pow(D1, v_scalar).val, rtol=1e-4, atol=1e-4 assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
) assert D1.shape == D2.shape
# D / v
D2 = D1 / v_scalar
assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D ^ v
D1 = diag(torch.arange(1, 4).to(ctx))
D2 = D1 ** v_scalar
assert torch.allclose(D1.val ** v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment