Unverified Commit e01580d3 authored by Mufei Li, committed by GitHub

[Sparse] Misc fix for existing elementwise ops (#5003)



* Update

* Update

* Update

* Update

* update

* lint

* CI

* CI

* CI
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent aad3bd04
"""DGL elementwise operators for diagonal matrix module."""
from typing import Union
from .diag_matrix import DiagMatrix
from .diag_matrix import diag, DiagMatrix
__all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
@@ -23,18 +23,22 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 = diag(torch.arange(1, 4))
>>> D2 = diag(torch.arange(10, 13))
>>> D1 + D2
DiagMatrix(val=tensor([11, 13, 15]),
shape=(3, 3))
"""
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" " D2 {} must match.".format(
D1.shape, D2.shape
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val + D2.val, D1.shape)
raise RuntimeError(
"Elementwise addition between "
f"{type(D1)} and {type(D2)} is not supported."
)
return DiagMatrix(D1.val + D2.val)
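# A minimal usage sketch for the updated diag_add, assuming `diag` and `torch`
# are available as in the docstring examples above: the result keeps the
# operand's (possibly rectangular) shape, and a non-DiagMatrix operand now
# fails with an explicit RuntimeError rather than an attribute error.
#
#   >>> D1 = diag(torch.arange(1, 4), shape=(3, 4))
#   >>> D2 = diag(torch.arange(10, 13), shape=(3, 4))
#   >>> (D1 + D2).val
#   tensor([11, 13, 15])
#   >>> (D1 + D2).shape
#   (3, 4)
#   >>> D1 + 1.0  # raises RuntimeError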
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
@@ -54,58 +58,66 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 -D2
>>> D1 = diag(torch.arange(1, 4))
>>> D2 = diag(torch.arange(10, 13))
>>> D1 - D2
DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3))
"""
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val - D2.val, D1.shape)
raise RuntimeError(
"Elementwise subtraction between "
f"{type(D1)} and {type(D2)} is not supported."
)
return DiagMatrix(D1.val - D2.val)
def diag_mul(
D1: Union[DiagMatrix, float], D2: Union[DiagMatrix, float]
D1: Union[DiagMatrix, float, int], D2: Union[DiagMatrix, float, int]
) -> DiagMatrix:
"""Elementwise multiplication.
Parameters
----------
D1 : DiagMatrix or scalar
Diagonal matrix or scalar value
D2 : DiagMatrix or scalar
Diagonal matrix or scalar value
Parameters
----------
D1 : DiagMatrix or float or int
Diagonal matrix or scalar value
D2 : DiagMatrix or float or int
Diagonal matrix or scalar value
Returns
-------
DiagMatrix
diagonal matrix
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
DiagMatrix(val=tensor([10, 22, 36]),
shape=(3, 3))
>>> D1 * 2.5
DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
shape=(3, 3))
>>> 2 * D1
DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3))
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D = diag(torch.arange(1, 4))
>>> D * 2.5
DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
shape=(3, 3))
>>> 2 * D
DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3))
"""
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
assert (
D1.shape == D2.shape
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return DiagMatrix(D1.val * D2.val)
return DiagMatrix(D1.val * D2)
return diag(D1.val * D2.val, D1.shape)
elif isinstance(D1, DiagMatrix) and isinstance(D2, (float, int)):
return diag(D1.val * D2, D1.shape)
elif isinstance(D1, (float, int)) and isinstance(D2, DiagMatrix):
return diag(D1 * D2.val, D2.shape)
raise RuntimeError(
"Elementwise multiplication between "
f"{type(D1)} and {type(D2)} is not supported."
)
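# A minimal usage sketch for the widened diag_mul, assuming `diag` and `torch`
# as in the docstring examples above: Python ints are now accepted on either
# side, and the operand's (possibly rectangular) shape is preserved.
#
#   >>> D = diag(torch.arange(1, 4), shape=(3, 4))
#   >>> (D * 2).val
#   tensor([2, 4, 6])
#   >>> (3 * D).shape
#   (3, 4)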
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
@@ -125,13 +137,12 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D2 = DiagMatrix(torch.arange(10, 13))
>>> D1 = diag(torch.arange(1, 4))
>>> D2 = diag(torch.arange(10, 13))
>>> D1 / D2
>>> D1/D2
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3))
>>> D1/2.5
>>> D1 / 2.5
DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
shape=(3, 3))
"""
@@ -141,8 +152,8 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
), "The shape of diagonal matrix D1 {} and" "D2 {} must match".format(
D1.shape, D2.shape
)
return DiagMatrix(D1.val / D2.val)
return DiagMatrix(D1.val / D2)
return diag(D1.val / D2.val, D1.shape)
return diag(D1.val / D2, D1.shape)
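# A minimal usage sketch for diag_div, assuming `diag` and `torch` as in the
# docstring examples above: scalar division now also keeps the operand's shape
# instead of rebuilding a square DiagMatrix from the values alone.
#
#   >>> D = diag(torch.arange(1, 4), shape=(3, 4))
#   >>> (D / 2).val
#   tensor([0.5000, 1.0000, 1.5000])
#   >>> (D / 2).shape
#   (3, 4)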
def diag_rdiv(D1: float, D2: DiagMatrix):
@@ -178,7 +189,7 @@ def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
Examples
--------
>>> D1 = DiagMatrix(torch.arange(1, 4))
>>> D1 = diag(torch.arange(1, 4))
>>> pow(D1, 2)
DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3))
......
"""DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix
__all__ = ["sp_add"]
def spsp_add(A, B):
""" Invoke C++ sparse library for addition """
"""Invoke C++ sparse library for addition"""
return SparseMatrix(
torch.ops.dgl_sparse.spsp_add(A.c_sparse_matrix, B.c_sparse_matrix)
)
def sp_add(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> SparseMatrix:
def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Elementwise addition.
Parameters
----------
A : SparseMatrix or DiagMatrix
Sparse matrix or diagonal matrix
B : SparseMatrix or DiagMatrix
Sparse matrix or diagonal matrix
A : SparseMatrix
Sparse matrix
B : SparseMatrix
Sparse matrix
Returns
-------
@@ -34,46 +30,17 @@ def sp_add(
Examples
--------
Case 1: Add two sparse matrices of same sparsity structure
>>> rowA = torch.tensor([1, 0, 2])
>>> colA = torch.tensor([0, 3, 2])
>>> valA = torch.tensor([10, 20, 30])
>>> A = SparseMatrix(rowA, colA, valA, shape=(3, 4))
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([10, 20, 30])
>>> A = create_from_coo(row, col, val, shape=(3, 4))
>>> A + A
SparseMatrix(indices=tensor([[0, 1, 2],
[3, 0, 2]]),
values=tensor([40, 20, 60]),
shape=(3, 4), nnz=3)
>>> w = torch.arange(1, len(rowA)+1)
>>> A + A(w)
SparseMatrix(indices=tensor([[0, 1, 2],
[3, 0, 2]]),
values=tensor([21, 12, 33]),
shape=(3, 4), nnz=3)
Case 2: Add two sparse matrices of different sparsity structure
>>> rowB = torch.tensor([1, 2, 0, 2, 1])
>>> colB = torch.tensor([0, 2, 1, 3, 3])
>>> valB = torch.tensor([1, 2, 3, 4, 5])
>>> B = SparseMatrix(rowB, colB, valB, shape=(3 ,4))
>>> A + B
SparseMatrix(indices=tensor([[0, 0, 1, 1, 2, 2],
[1, 3, 0, 3, 2, 3]]),
values=tensor([ 3, 20, 11, 5, 32, 4]),
shape=(3, 4), nnz=6)
Case 3: Add sparse matrix and diagonal matrix
>>> D = diag(torch.arange(2, 5), shape=A.shape)
>>> A + D
SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
[0, 3, 0, 1, 2]]),
values=tensor([ 2, 20, 10, 3, 34]),
shape=(3, 4), nnz=5)
"""
B = B.as_sparse() if isinstance(B, DiagMatrix) else B
if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
return spsp_add(A, B)
raise RuntimeError(
......
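# A minimal usage sketch for the narrowed sp_add, assuming `A` as in the
# docstring above, `D = diag(torch.arange(2, 5), shape=A.shape)` as in the
# removed Case 3, and that DiagMatrix still exposes the as_sparse() conversion
# used on the removed line: diagonal operands are no longer converted
# implicitly, so a mixed sum needs an explicit conversion.
#
#   >>> A + A                # SparseMatrix + SparseMatrix goes through spsp_add
#   >>> A + D.as_sparse()    # convert the DiagMatrix explicitly
#   >>> A + D                # now raises RuntimeError (assuming the operator
#   ...                      # overload dispatches to sp_add)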
import operator
import sys
import backend as F
import numpy as np
import pytest
import torch
import sys
from dgl.mock_sparse2 import diag
# TODO(#4818): Skipping tests on win.
@@ -21,23 +22,38 @@ def all_close_sparse(A, B):
"op", [operator.add, operator.sub, operator.mul, operator.truediv]
)
def test_diag_op_diag(op):
D1 = diag(torch.arange(1, 4))
D2 = diag(torch.arange(10, 13))
assert np.allclose(op(D1, D2).val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)
result = op(D1, D2)
assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
assert result.shape == D1.shape
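# A hypothetical companion test, sketched here against the shape assertion
# added to the diagonal ops; it reuses the imports already in this file.
def test_diag_op_shape_mismatch():
    ctx = F.ctx()
    D1 = diag(torch.arange(1, 4).to(ctx), shape=(3, 4))
    D2 = diag(torch.arange(1, 4).to(ctx), shape=(3, 3))
    # Mismatched shapes should trip the assert in diag_add and friends.
    with pytest.raises(AssertionError):
        D1 + D2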
@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_diag_op_scalar(v_scalar):
D1 = diag(torch.arange(1, 50))
assert np.allclose(
D1.val * v_scalar, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
v_scalar * D1.val, (D1 * v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
D1.val / v_scalar, (D1 / v_scalar).val, rtol=1e-4, atol=1e-4
)
assert np.allclose(
pow(D1.val, v_scalar), pow(D1, v_scalar).val, rtol=1e-4, atol=1e-4
)
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
# D * v
D2 = D1 * v_scalar
assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# v * D
D2 = v_scalar * D1
assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D / v
D2 = D1 / v_scalar
assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D ^ v
D1 = diag(torch.arange(1, 4).to(ctx))
D2 = D1 ** v_scalar
assert torch.allclose(D1.val ** v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
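# A hypothetical sketch of exercising the new RuntimeError paths, assuming the
# DiagMatrix operator overloads dispatch to diag_add/diag_mul as the docstring
# examples suggest.
def test_diag_op_unsupported_operand():
    ctx = F.ctx()
    D = diag(torch.arange(1, 4).to(ctx))
    # Adding a plain scalar is not a supported elementwise op for DiagMatrix.
    with pytest.raises(RuntimeError):
        D + 1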