"examples/vscode:/vscode.git/clone" did not exist on "001d79371144612ebec69ec000b9eb25d8a6a818"
Unverified Commit 27010dbc authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sparse] Sparse matrix subtraction (#5164)



* use NotImplemented

* format

* sparse sub

* address comments

* lint

* Update elementwise_op_diag.py

* ugh
Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 419fb815
......@@ -56,7 +56,9 @@ def add(
return A + B
def sub(
    A: Union[DiagMatrix, SparseMatrix], B: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
    r"""Elementwise subtraction for ``DiagMatrix`` and ``SparseMatrix``,
    equivalent to ``A - B``.

    The supported combinations are shown as follows.

    +--------------+------------+--------------+--------+
    | A \\ B       | DiagMatrix | SparseMatrix | scalar |
    +--------------+------------+--------------+--------+
    | DiagMatrix   |     ✅     |      ✅      |   🚫   |
    +--------------+------------+--------------+--------+
    | SparseMatrix |     ✅     |      ✅      |   🚫   |
    +--------------+------------+--------------+--------+
    | scalar       |     🚫     |      🚫      |   🚫   |
    +--------------+------------+--------------+--------+

    Parameters
    ----------
    A : DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix
    B : DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix

    Returns
    -------
    DiagMatrix or SparseMatrix
        Diagonal matrix if both :attr:`A` and :attr:`B` are diagonal matrices,
        sparse matrix otherwise

    Examples
    --------

    >>> row = torch.tensor([1, 0, 2])
    >>> col = torch.tensor([0, 1, 2])
    >>> val = torch.tensor([10, 20, 30])
    >>> A = from_coo(row, col, val)
    >>> B = diag(torch.arange(1, 4))
    >>> sub(A, B)
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
                                 [0, 1, 0, 1, 2]]),
                 values=tensor([-1, 20, 10, -2, 27]),
                 shape=(3, 3), nnz=5)
    """
    # Delegate to the operand types' __sub__/__rsub__ implementations; an
    # unsupported pairing surfaces as the standard TypeError from Python's
    # binary-operator protocol.
    return A - B
......
......@@ -49,20 +49,22 @@ def diag_add(
return NotImplemented
def diag_sub(
    D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
    """Elementwise subtraction

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix

    Returns
    -------
    DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix, same as D2

    Examples
    --------

    >>> D1 = diag(torch.arange(1, 4))
    >>> D2 = diag(torch.arange(10, 13))
    >>> D1 - D2
    DiagMatrix(val=tensor([-9, -9, -9]),
               shape=(3, 3))
    """
    if isinstance(D2, DiagMatrix):
        assert D1.shape == D2.shape, (
            "The shape of diagonal matrix D1 "
            f"{D1.shape} and D2 {D2.shape} must match."
        )
        # Both diagonal: subtract the diagonal value tensors directly.
        return diag(D1.val - D2.val, D1.shape)
    elif isinstance(D2, SparseMatrix):
        assert D1.shape == D2.shape, (
            "The shape of diagonal matrix D1 "
            f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
        )
        # Promote D1 to sparse so SparseMatrix.__sub__ handles the math.
        D1 = D1.to_sparse()
        return D1 - D2
    # Python falls back to D2.__rsub__(D1) then TypeError when NotImplemented
    # is returned.
    return NotImplemented
def diag_rsub(
    D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
    """Elementwise subtraction in the opposite direction (``D2 - D1``)

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix

    Returns
    -------
    DiagMatrix or SparseMatrix
        Diagonal matrix or sparse matrix, same as D2

    Examples
    --------

    >>> D1 = diag(torch.arange(1, 4))
    >>> D2 = diag(torch.arange(10, 13))
    >>> D2 - D1
    DiagMatrix(val=tensor([-9, -9, -9]),
               shape=(3, 3))
    """
    # Compute D2 - D1 as -(D1 - D2), but guard the unsupported-operand case:
    # negating NotImplemented would raise a confusing
    # "bad operand type for unary -" TypeError instead of letting Python's
    # binary-operator fallback report the real incompatibility.
    diff = D1 - D2
    if diff is NotImplemented:
        return NotImplemented
    return -diff
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise multiplication
......@@ -197,6 +234,7 @@ def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix:
# Hook the diagonal-matrix elementwise functions into Python's operator
# protocol so `D1 + D2`, `D1 - D2`, etc. dispatch to them.
DiagMatrix.__add__ = diag_add
DiagMatrix.__radd__ = diag_add  # reflected add reuses diag_add
DiagMatrix.__sub__ = diag_sub
DiagMatrix.__rsub__ = diag_rsub  # reflected sub needs its own handler
DiagMatrix.__mul__ = diag_mul
DiagMatrix.__rmul__ = diag_mul  # reflected mul reuses diag_mul
DiagMatrix.__truediv__ = diag_div
......
......@@ -45,7 +45,43 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented
def sp_sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
    """Elementwise subtraction

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    B : SparseMatrix
        Sparse matrix

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------

    >>> row = torch.tensor([1, 0, 2])
    >>> col = torch.tensor([0, 3, 2])
    >>> val = torch.tensor([10, 20, 30])
    >>> val2 = torch.tensor([5, 10, 15])
    >>> A = from_coo(row, col, val, shape=(3, 4))
    >>> B = from_coo(row, col, val2, shape=(3, 4))
    >>> A - B
    SparseMatrix(indices=tensor([[0, 1, 2],
                                 [3, 0, 2]]),
                 values=tensor([10, 5, 15]),
                 shape=(3, 4), nnz=3)
    """
    if not isinstance(B, SparseMatrix):
        # Python falls back to B.__rsub__ then TypeError when NotImplemented
        # is returned.
        return NotImplemented
    # A - B is realized as A + (-B) on the sparse-sparse addition kernel.
    return spsp_add(A, -B)
def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
"""Elementwise multiplication
Parameters
......@@ -157,6 +193,7 @@ def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
# Hook the sparse-matrix elementwise functions into Python's operator
# protocol so `A + B`, `A - B`, etc. dispatch to them.
SparseMatrix.__add__ = sp_add
SparseMatrix.__sub__ = sp_sub
SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = sp_mul  # reflected mul reuses sp_mul
SparseMatrix.__truediv__ = sp_div
......
......@@ -2,108 +2,125 @@ import operator
import sys
import backend as F
import dgl.sparse as dglsp
import pytest
import torch
from dgl.sparse import add, diag, from_coo, from_csc, from_csr
# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)
@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_coo(val_shape, opname):
    """Add/sub of two COO sparse matrices matches the dense computation."""
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)
    row = torch.tensor([1, 0]).to(ctx)
    col = torch.tensor([0, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    B = dglsp.from_coo(row, col, val, shape=A.shape)

    # Operator form and functional form must both agree with dense math.
    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())
    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Scalar operands are unsupported for add/sub in either direction.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)
@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csr(val_shape, opname):
    """Add/sub of two CSR sparse matrices matches the dense computation."""
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)
    ctx = F.ctx()
    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)
    indices = torch.tensor([3, 0, 2]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csr(indptr, indices, val)
    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)
    indices = torch.tensor([2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csr(indptr, indices, val, shape=A.shape)

    # Operator form and functional form must both agree with dense math.
    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())
    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Scalar operands are unsupported for add/sub in either direction.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)
@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csc(val_shape, opname):
    """Add/sub of two CSC sparse matrices matches the dense computation."""
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)
    ctx = F.ctx()
    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)
    indices = torch.tensor([1, 2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csc(indptr, indices, val)
    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)
    indices = torch.tensor([1, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csc(indptr, indices, val, shape=A.shape)

    # Operator form and functional form must both agree with dense math.
    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())
    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Scalar operands are unsupported for add/sub in either direction.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)
@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_diag(val_shape, opname):
    """Add/sub of two diagonal matrices matches the dense computation."""
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)
    ctx = F.ctx()
    shape = (3, 4)
    val_shape = (shape[0],) + val_shape
    D1 = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)
    D2 = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)

    # Operator form and functional form must both agree with dense math.
    C1 = op(D1, D2).to_dense()
    C2 = func(D1, D2).to_dense()
    dense_C = op(D1.to_dense(), D2.to_dense())
    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Scalar operands are unsupported for add/sub in either direction.
    with pytest.raises(TypeError):
        op(D1, 2)
    with pytest.raises(TypeError):
        op(2, D1)
@pytest.mark.parametrize("val_shape", [(), (2,)])
......@@ -112,16 +129,16 @@ def test_add_sparse_diag(val_shape):
row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx)
val = torch.randn(row.shape + val_shape).to(ctx)
A = from_coo(row, col, val)
A = dglsp.from_coo(row, col, val)
shape = (3, 4)
val_shape = (shape[0],) + val_shape
D = diag(torch.randn(val_shape).to(ctx), shape=shape)
D = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)
sum1 = (A + D).to_dense()
sum2 = (D + A).to_dense()
sum3 = add(A, D).to_dense()
sum4 = add(D, A).to_dense()
sum3 = dglsp.add(A, D).to_dense()
sum4 = dglsp.add(D, A).to_dense()
dense_sum = A.to_dense() + D.to_dense()
assert torch.allclose(dense_sum, sum1)
......@@ -130,16 +147,40 @@ def test_add_sparse_diag(val_shape):
assert torch.allclose(dense_sum, sum4)
@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_sub_sparse_diag(val_shape):
    """A - D and D - A for a COO sparse matrix A and diagonal matrix D."""
    dev = F.ctx()
    rows = torch.tensor([1, 0, 2]).to(dev)
    cols = torch.tensor([0, 3, 2]).to(dev)
    sp_vals = torch.randn(rows.shape + val_shape).to(dev)
    A = dglsp.from_coo(rows, cols, sp_vals)

    mat_shape = (3, 4)
    diag_vals = torch.randn((mat_shape[0],) + val_shape).to(dev)
    D = dglsp.diag(diag_vals, shape=mat_shape)

    # Both operator and functional forms, in both directions, must match the
    # dense difference (the reversed direction is its negation).
    expected = A.to_dense() - D.to_dense()
    assert torch.allclose(expected, (A - D).to_dense())
    assert torch.allclose(expected, -(D - A).to_dense())
    assert torch.allclose(expected, dglsp.sub(A, D).to_dense())
    assert torch.allclose(expected, -dglsp.sub(D, A).to_dense())
@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
def test_error_op_sparse_diag(op):
    """mul/truediv/pow between sparse and diagonal matrices must raise."""
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape).to(ctx)
    A = dglsp.from_coo(row, col, val)
    shape = (3, 4)
    D = dglsp.diag(torch.randn(row.shape[0]).to(ctx), shape=shape)

    # These elementwise ops are only defined for scalar operands, so mixing
    # a sparse matrix with a diagonal matrix falls through to TypeError.
    with pytest.raises(TypeError):
        getattr(operator, op)(A, D)
......
......@@ -18,8 +18,9 @@ def all_close_sparse(A, B):
assert A.shape == B.shape
@pytest.mark.parametrize("op", [operator.sub, operator.mul, operator.truediv])
def test_diag_op_diag(op):
@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
def test_diag_op_diag(opname):
op = getattr(operator, opname)
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment