Unverified Commit 8abf9d54 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sparse] Use NotImplemented to let Python dispatch types (#5160)

* use NotImplemented

* add TypeError tests

* format

* lint

* fix type

* fix docstring

* lint

* remove redundant condition
parent 751b4c26
...@@ -30,23 +30,22 @@ def diag_add( ...@@ -30,23 +30,22 @@ def diag_add(
DiagMatrix(val=tensor([11, 13, 15]), DiagMatrix(val=tensor([11, 13, 15]),
shape=(3, 3)) shape=(3, 3))
""" """
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix): if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, ( assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 " "The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match." f"{D1.shape} and D2 {D2.shape} must match."
) )
return diag(D1.val + D2.val, D1.shape) return diag(D1.val + D2.val, D1.shape)
elif isinstance(D1, DiagMatrix) and isinstance(D2, SparseMatrix): elif isinstance(D2, SparseMatrix):
assert D1.shape == D2.shape, ( assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 " "The shape of diagonal matrix D1 "
f"{D1.shape} and sparse matrix D2 {D2.shape} must match." f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
) )
D1 = D1.as_sparse() D1 = D1.as_sparse()
return D1 + D2 return D1 + D2
raise RuntimeError( # Python falls back to D2.__radd__(D1) then TypeError when NotImplemented
"Elementwise addition between " # is returned.
f"{type(D1)} and {type(D2)} is not supported." return NotImplemented
)
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
...@@ -72,27 +71,24 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: ...@@ -72,27 +71,24 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
DiagMatrix(val=tensor([-9, -9, -9]), DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3)) shape=(3, 3))
""" """
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix): if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, ( assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 " "The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match." f"{D1.shape} and D2 {D2.shape} must match."
) )
return diag(D1.val - D2.val, D1.shape) return diag(D1.val - D2.val, D1.shape)
raise RuntimeError( # Python falls back to D2.__rsub__(D1) then TypeError when NotImplemented
"Elementwise subtraction between " # is returned.
f"{type(D1)} and {type(D2)} is not supported." return NotImplemented
)
def diag_mul( def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
D1: Union[DiagMatrix, float, int], D2: Union[DiagMatrix, float, int]
) -> DiagMatrix:
"""Elementwise multiplication """Elementwise multiplication
Parameters Parameters
---------- ----------
D1 : DiagMatrix or float or int D1 : DiagMatrix
Diagonal matrix or scalar value Diagonal matrix
D2 : DiagMatrix or float or int D2 : DiagMatrix or float or int
Diagonal matrix or scalar value Diagonal matrix or scalar value
...@@ -111,21 +107,18 @@ def diag_mul( ...@@ -111,21 +107,18 @@ def diag_mul(
DiagMatrix(val=tensor([2, 4, 6]), DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3)) shape=(3, 3))
""" """
if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix): if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, ( assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 " "The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match." f"{D1.shape} and D2 {D2.shape} must match."
) )
return diag(D1.val * D2.val, D1.shape) return diag(D1.val * D2.val, D1.shape)
elif isinstance(D1, DiagMatrix) and isinstance(D2, (float, int)): elif isinstance(D2, (float, int)):
return diag(D1.val * D2, D1.shape) return diag(D1.val * D2, D1.shape)
elif isinstance(D1, (float, int)) and isinstance(D2, DiagMatrix): else:
return diag(D1 * D2.val, D2.shape) # Python falls back to D2.__rmul__(D1) then TypeError when
# NotImplemented is returned.
raise RuntimeError( return NotImplemented
"Elementwise multiplication between "
f"{type(D1)} and {type(D2)} is not supported."
)
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
...@@ -165,28 +158,10 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -165,28 +158,10 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
elif isinstance(D2, (float, int)): elif isinstance(D2, (float, int)):
assert D2 != 0, "Division by zero is not allowed." assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape) return diag(D1.val / D2, D1.shape)
else:
raise RuntimeError( # Python falls back to D2.__rtruediv__(D1) then TypeError when
f"Elementwise division between a diagonal matrix and {type(D2)} is " # NotImplemented is returned.
"not supported." return NotImplemented
)
def diag_rdiv(D1: DiagMatrix, D2: Union[float, int]):
"""Function for preventing elementwise division of a scalar by a diagonal
matrix
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : float or int
Scalar value
"""
raise RuntimeError(
f"Elementwise division of {type(D2)} by a diagonal matrix is not "
"supported."
)
# pylint: disable=invalid-name # pylint: disable=invalid-name
...@@ -213,37 +188,17 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: ...@@ -213,37 +188,17 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
DiagMatrix(val=tensor([1, 4, 9]), DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3)) shape=(3, 3))
""" """
if isinstance(scalar, (float, int)): return (
return diag(D.val**scalar, D.shape) diag(D.val**scalar, D.shape)
if isinstance(scalar, (float, int))
raise RuntimeError( else NotImplemented
f"Raising a diagonal matrix to exponent {type(scalar)} is not allowed."
)
def diag_rpower(D: DiagMatrix, scalar: Union[float, int]):
"""Function for preventing raising a scalar to a diagonal matrix exponent
Parameters
----------
D : DiagMatrix
Diagonal matrix
scalar : float or int
Scalar
"""
raise RuntimeError(
f"Raising {type(scalar)} to a diagonal matrix component is not "
"allowed."
) )
DiagMatrix.__add__ = diag_add DiagMatrix.__add__ = diag_add
DiagMatrix.__radd__ = diag_add DiagMatrix.__radd__ = diag_add
DiagMatrix.__sub__ = diag_sub DiagMatrix.__sub__ = diag_sub
DiagMatrix.__rsub__ = diag_sub
DiagMatrix.__mul__ = diag_mul DiagMatrix.__mul__ = diag_mul
DiagMatrix.__rmul__ = diag_mul DiagMatrix.__rmul__ = diag_mul
DiagMatrix.__truediv__ = diag_div DiagMatrix.__truediv__ = diag_div
DiagMatrix.__rtruediv__ = diag_rdiv
DiagMatrix.__pow__ = diag_power DiagMatrix.__pow__ = diag_power
DiagMatrix.__rpow__ = diag_rpower
...@@ -3,7 +3,6 @@ from typing import Union ...@@ -3,7 +3,6 @@ from typing import Union
import torch import torch
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix, val_like from .sparse_matrix import SparseMatrix, val_like
...@@ -14,15 +13,15 @@ def spsp_add(A, B): ...@@ -14,15 +13,15 @@ def spsp_add(A, B):
) )
def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix: def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Elementwise addition """Elementwise addition
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
Sparse matrix Sparse matrix
B : DiagMatrix or SparseMatrix B : SparseMatrix
Diagonal matrix or sparse matrix Sparse matrix
Returns Returns
------- -------
...@@ -42,26 +41,19 @@ def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix: ...@@ -42,26 +41,19 @@ def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix:
values=tensor([40, 20, 60]), values=tensor([40, 20, 60]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
if isinstance(B, DiagMatrix): # Python falls back to B.__radd__ then TypeError when NotImplemented is
B = B.as_sparse() # returned.
if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix): return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented
return spsp_add(A, B)
raise RuntimeError(
"Elementwise addition between {} and {} is not "
"supported.".format(type(A), type(B))
)
def sp_mul( def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
A: Union[SparseMatrix, float, int], B: Union[SparseMatrix, float, int]
) -> SparseMatrix:
"""Elementwise multiplication """Elementwise multiplication
Parameters Parameters
---------- ----------
A : SparseMatrix or float or int A : SparseMatrix
First operand First operand
B : SparseMatrix or float or int B : float or int
Second operand Second operand
Returns Returns
...@@ -89,14 +81,13 @@ def sp_mul( ...@@ -89,14 +81,13 @@ def sp_mul(
values=tensor([2, 4, 6]), values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
if isinstance(A, SparseMatrix) and isinstance(B, (float, int)): if isinstance(B, (float, int)):
return val_like(A, A.val * B) return val_like(A, A.val * B)
elif isinstance(A, (float, int)) and isinstance(B, SparseMatrix): # Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
return val_like(B, A * B.val) # returned.
raise RuntimeError( # So this also handles the case of scalar * SparseMatrix since we set
"Elementwise multiplication between " # SparseMatrix.__rmul__ to be the same as SparseMatrix.__mul__.
f"{type(A)} and {type(B)} is not supported." return NotImplemented
)
def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
...@@ -127,32 +118,16 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: ...@@ -127,32 +118,16 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
values=tensor([100, 400, 900]), values=tensor([100, 400, 900]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
if isinstance(scalar, (float, int)): # Python falls back to scalar.__rpow__ then TypeError when NotImplemented
return val_like(A, A.val**scalar) # is returned.
return (
raise RuntimeError( val_like(A, A.val**scalar)
f"Raising a sparse matrix to exponent {type(scalar)} is not allowed." if isinstance(scalar, (float, int))
) else NotImplemented
def sp_rpower(A: SparseMatrix, scalar: Union[float, int]):
"""Function for preventing raising a scalar to a sparse matrix exponent
Parameters
----------
A : SparseMatrix
Sparse matrix
scalar : float or int
Scalar
"""
raise RuntimeError(
f"Raising {type(scalar)} to a sparse matrix component is not allowed."
) )
SparseMatrix.__add__ = sp_add SparseMatrix.__add__ = sp_add
SparseMatrix.__radd__ = sp_add
SparseMatrix.__mul__ = sp_mul SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = sp_mul SparseMatrix.__rmul__ = sp_mul
SparseMatrix.__pow__ = sp_power SparseMatrix.__pow__ = sp_power
SparseMatrix.__rpow__ = sp_rpower
import operator
import sys import sys
import backend as F import backend as F
...@@ -31,6 +32,11 @@ def test_add_coo(val_shape): ...@@ -31,6 +32,11 @@ def test_add_coo(val_shape):
assert torch.allclose(dense_sum, sum1) assert torch.allclose(dense_sum, sum1)
assert torch.allclose(dense_sum, sum2) assert torch.allclose(dense_sum, sum2)
with pytest.raises(TypeError):
A + 2
with pytest.raises(TypeError):
2 + A
@pytest.mark.parametrize("val_shape", [(), (2,)]) @pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csr(val_shape): def test_add_csr(val_shape):
...@@ -52,6 +58,11 @@ def test_add_csr(val_shape): ...@@ -52,6 +58,11 @@ def test_add_csr(val_shape):
assert torch.allclose(dense_sum, sum1) assert torch.allclose(dense_sum, sum1)
assert torch.allclose(dense_sum, sum2) assert torch.allclose(dense_sum, sum2)
with pytest.raises(TypeError):
A + 2
with pytest.raises(TypeError):
2 + A
@pytest.mark.parametrize("val_shape", [(), (2,)]) @pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csc(val_shape): def test_add_csc(val_shape):
...@@ -73,6 +84,11 @@ def test_add_csc(val_shape): ...@@ -73,6 +84,11 @@ def test_add_csc(val_shape):
assert torch.allclose(dense_sum, sum1) assert torch.allclose(dense_sum, sum1)
assert torch.allclose(dense_sum, sum2) assert torch.allclose(dense_sum, sum2)
with pytest.raises(TypeError):
A + 2
with pytest.raises(TypeError):
2 + A
@pytest.mark.parametrize("val_shape", [(), (2,)]) @pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_diag(val_shape): def test_add_diag(val_shape):
...@@ -112,3 +128,20 @@ def test_add_sparse_diag(val_shape): ...@@ -112,3 +128,20 @@ def test_add_sparse_diag(val_shape):
assert torch.allclose(dense_sum, sum2) assert torch.allclose(dense_sum, sum2)
assert torch.allclose(dense_sum, sum3) assert torch.allclose(dense_sum, sum3)
assert torch.allclose(dense_sum, sum4) assert torch.allclose(dense_sum, sum4)
@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
def test_error_op_sparse_diag(op):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx)
val = torch.randn(row.shape).to(ctx)
A = from_coo(row, col, val)
shape = (3, 4)
D = diag(torch.randn(row.shape[0]).to(ctx), shape=shape)
with pytest.raises(TypeError):
getattr(operator, op)(A, D)
with pytest.raises(TypeError):
getattr(operator, op)(D, A)
...@@ -60,3 +60,13 @@ def test_diag_op_scalar(v_scalar): ...@@ -60,3 +60,13 @@ def test_diag_op_scalar(v_scalar):
D2 = power(D1, v_scalar) D2 = power(D1, v_scalar)
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4) assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape assert D1.shape == D2.shape
with pytest.raises(TypeError):
D1 + v_scalar
with pytest.raises(TypeError):
v_scalar + D1
with pytest.raises(TypeError):
D1 - v_scalar
with pytest.raises(TypeError):
v_scalar - D1
...@@ -62,3 +62,23 @@ def test_pow(val_shape): ...@@ -62,3 +62,23 @@ def test_pow(val_shape):
new_row, new_col = A_new.coo() new_row, new_col = A_new.coo()
assert torch.allclose(new_row, row) assert torch.allclose(new_row, row)
assert torch.allclose(new_col, col) assert torch.allclose(new_col, col)
@pytest.mark.parametrize("op", ["add", "sub"])
@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_error_op_scalar(op, v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx)
val = torch.randn(len(row)).to(ctx)
A = from_coo(row, col, val, shape=(3, 4))
with pytest.raises(TypeError):
A + v_scalar
with pytest.raises(TypeError):
v_scalar + A
with pytest.raises(TypeError):
A - v_scalar
with pytest.raises(TypeError):
v_scalar - A
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment