"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6ab2dd18a4d17d90c92409886ac22a02acf25d7d"
Unverified Commit 8abf9d54 authored by Quan (Andy) Gan, committed by GitHub

[Sparse] Use NotImplemented to let Python dispatch types (#5160)

* use NotImplemented

* add TypeError tests

* format

* lint

* fix type

* fix docstring

* lint

* remove redundant condition
parent 751b4c26
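The change swaps the explicit `RuntimeError` raises for Python's standard binary-operator protocol: when an operand type is not handled, the dunder method returns `NotImplemented`, Python then tries the reflected method on the other operand, and only raises `TypeError` if both sides decline. A minimal sketch of that protocol follows (a toy `Meters` class, not part of the DGL code):

import operator


class Meters:
    """Toy value type illustrating the NotImplemented protocol."""

    def __init__(self, val):
        self.val = val

    def __add__(self, other):
        if isinstance(other, Meters):
            return Meters(self.val + other.val)
        # Decline: Python will try other.__radd__(self) next, and raise
        # TypeError only if that also returns NotImplemented.
        return NotImplemented

    __radd__ = __add__


Meters(1) + Meters(2)   # works
operator.add(Meters(1), "3")   # raises TypeError via Python's fallback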
@@ -30,23 +30,22 @@ def diag_add(
     DiagMatrix(val=tensor([11, 13, 15]),
     shape=(3, 3))
     """
-    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
+    if isinstance(D2, DiagMatrix):
         assert D1.shape == D2.shape, (
             "The shape of diagonal matrix D1 "
             f"{D1.shape} and D2 {D2.shape} must match."
         )
         return diag(D1.val + D2.val, D1.shape)
-    elif isinstance(D1, DiagMatrix) and isinstance(D2, SparseMatrix):
+    elif isinstance(D2, SparseMatrix):
         assert D1.shape == D2.shape, (
             "The shape of diagonal matrix D1 "
             f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
         )
         D1 = D1.as_sparse()
         return D1 + D2
-    raise RuntimeError(
-        "Elementwise addition between "
-        f"{type(D1)} and {type(D2)} is not supported."
-    )
+    # Python falls back to D2.__radd__(D1) then TypeError when NotImplemented
+    # is returned.
+    return NotImplemented


 def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
@@ -72,27 +71,24 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
     DiagMatrix(val=tensor([-9, -9, -9]),
     shape=(3, 3))
     """
-    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
+    if isinstance(D2, DiagMatrix):
         assert D1.shape == D2.shape, (
             "The shape of diagonal matrix D1 "
             f"{D1.shape} and D2 {D2.shape} must match."
         )
         return diag(D1.val - D2.val, D1.shape)
-    raise RuntimeError(
-        "Elementwise subtraction between "
-        f"{type(D1)} and {type(D2)} is not supported."
-    )
+    # Python falls back to D2.__rsub__(D1) then TypeError when NotImplemented
+    # is returned.
+    return NotImplemented


-def diag_mul(
-    D1: Union[DiagMatrix, float, int], D2: Union[DiagMatrix, float, int]
-) -> DiagMatrix:
+def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
     """Elementwise multiplication

     Parameters
     ----------
-    D1 : DiagMatrix or float or int
-        Diagonal matrix or scalar value
+    D1 : DiagMatrix
+        Diagonal matrix
     D2 : DiagMatrix or float or int
         Diagonal matrix or scalar value
@@ -111,21 +107,18 @@ def diag_mul(
     DiagMatrix(val=tensor([2, 4, 6]),
     shape=(3, 3))
     """
-    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
+    if isinstance(D2, DiagMatrix):
         assert D1.shape == D2.shape, (
             "The shape of diagonal matrix D1 "
             f"{D1.shape} and D2 {D2.shape} must match."
         )
         return diag(D1.val * D2.val, D1.shape)
-    elif isinstance(D1, DiagMatrix) and isinstance(D2, (float, int)):
+    elif isinstance(D2, (float, int)):
         return diag(D1.val * D2, D1.shape)
-    elif isinstance(D1, (float, int)) and isinstance(D2, DiagMatrix):
-        return diag(D1 * D2.val, D2.shape)
-    raise RuntimeError(
-        "Elementwise multiplication between "
-        f"{type(D1)} and {type(D2)} is not supported."
-    )
+    else:
+        # Python falls back to D2.__rmul__(D1) then TypeError when
+        # NotImplemented is returned.
+        return NotImplemented


 def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
@@ -165,28 +158,10 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
     elif isinstance(D2, (float, int)):
         assert D2 != 0, "Division by zero is not allowed."
         return diag(D1.val / D2, D1.shape)
-    raise RuntimeError(
-        f"Elementwise division between a diagonal matrix and {type(D2)} is "
-        "not supported."
-    )
-
-
-def diag_rdiv(D1: DiagMatrix, D2: Union[float, int]):
-    """Function for preventing elementwise division of a scalar by a diagonal
-    matrix
-
-    Parameters
-    ----------
-    D1 : DiagMatrix
-        Diagonal matrix
-    D2 : float or int
-        Scalar value
-    """
-    raise RuntimeError(
-        f"Elementwise division of {type(D2)} by a diagonal matrix is not "
-        "supported."
-    )
+    else:
+        # Python falls back to D2.__rtruediv__(D1) then TypeError when
+        # NotImplemented is returned.
+        return NotImplemented


 # pylint: disable=invalid-name
@@ -213,37 +188,17 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
     DiagMatrix(val=tensor([1, 4, 9]),
     shape=(3, 3))
     """
-    if isinstance(scalar, (float, int)):
-        return diag(D.val**scalar, D.shape)
-    raise RuntimeError(
-        f"Raising a diagonal matrix to exponent {type(scalar)} is not allowed."
-    )
-
-
-def diag_rpower(D: DiagMatrix, scalar: Union[float, int]):
-    """Function for preventing raising a scalar to a diagonal matrix exponent
-
-    Parameters
-    ----------
-    D : DiagMatrix
-        Diagonal matrix
-    scalar : float or int
-        Scalar
-    """
-    raise RuntimeError(
-        f"Raising {type(scalar)} to a diagonal matrix component is not "
-        "allowed."
-    )
+    return (
+        diag(D.val**scalar, D.shape)
+        if isinstance(scalar, (float, int))
+        else NotImplemented
+    )


 DiagMatrix.__add__ = diag_add
 DiagMatrix.__radd__ = diag_add
 DiagMatrix.__sub__ = diag_sub
 DiagMatrix.__rsub__ = diag_sub
 DiagMatrix.__mul__ = diag_mul
 DiagMatrix.__rmul__ = diag_mul
 DiagMatrix.__truediv__ = diag_div
-DiagMatrix.__rtruediv__ = diag_rdiv
 DiagMatrix.__pow__ = diag_power
-DiagMatrix.__rpow__ = diag_rpower
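With the reflected-division and reflected-power handlers deleted above, any operand pairing that the `diag_*` functions do not recognize now surfaces as a plain `TypeError` raised by Python itself instead of a library `RuntimeError`. A short usage sketch of the new behavior (the `diag` constructor and import path are assumptions taken from the docstrings above; the exact error message comes from CPython, not DGL):

import torch
from dgl.sparse import diag  # import path assumed from the docstring examples

D = diag(torch.arange(1, 4))

D * 2        # DiagMatrix with val tensor([2, 4, 6])
2 * D        # also works: int.__mul__ declines, then D.__rmul__ handles it
try:
    D + 2    # diag_add declines and int.__radd__ declines
except TypeError as err:
    print(err)  # "unsupported operand type(s) for +: ..."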
@@ -3,7 +3,6 @@
 from typing import Union

 import torch

-from .diag_matrix import DiagMatrix
 from .sparse_matrix import SparseMatrix, val_like
@@ -14,15 +13,15 @@ def spsp_add(A, B):
     )


-def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix:
+def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
     """Elementwise addition

     Parameters
     ----------
     A : SparseMatrix
         Sparse matrix
-    B : DiagMatrix or SparseMatrix
-        Diagonal matrix or sparse matrix
+    B : SparseMatrix
+        Sparse matrix

     Returns
     -------
@@ -38,30 +37,23 @@ def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix:
     >>> A = from_coo(row, col, val, shape=(3, 4))
     >>> A + A
     SparseMatrix(indices=tensor([[0, 1, 2],
-    [3, 0, 2]]),
-    values=tensor([40, 20, 60]),
-    shape=(3, 4), nnz=3)
+                                  [3, 0, 2]]),
+                 values=tensor([40, 20, 60]),
+                 shape=(3, 4), nnz=3)
     """
-    if isinstance(B, DiagMatrix):
-        B = B.as_sparse()
-    if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
-        return spsp_add(A, B)
-    raise RuntimeError(
-        "Elementwise addition between {} and {} is not "
-        "supported.".format(type(A), type(B))
-    )
+    # Python falls back to B.__radd__ then TypeError when NotImplemented is
+    # returned.
+    return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented


-def sp_mul(
-    A: Union[SparseMatrix, float, int], B: Union[SparseMatrix, float, int]
-) -> SparseMatrix:
+def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
     """Elementwise multiplication

     Parameters
     ----------
-    A : SparseMatrix or float or int
+    A : SparseMatrix
         First operand
-    B : SparseMatrix or float or int
+    B : float or int
         Second operand

     Returns
@@ -89,14 +81,13 @@ def sp_mul(
     values=tensor([2, 4, 6]),
     shape=(3, 4), nnz=3)
     """
-    if isinstance(A, SparseMatrix) and isinstance(B, (float, int)):
+    if isinstance(B, (float, int)):
         return val_like(A, A.val * B)
-    elif isinstance(A, (float, int)) and isinstance(B, SparseMatrix):
-        return val_like(B, A * B.val)
-    raise RuntimeError(
-        "Elementwise multiplication between "
-        f"{type(A)} and {type(B)} is not supported."
-    )
+    # Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
+    # returned.
+    # So this also handles the case of scalar * SparseMatrix since we set
+    # SparseMatrix.__rmul__ to be the same as SparseMatrix.__mul__.
+    return NotImplemented


 def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
@@ -127,32 +118,16 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
     values=tensor([100, 400, 900]),
     shape=(3, 4), nnz=3)
     """
-    if isinstance(scalar, (float, int)):
-        return val_like(A, A.val**scalar)
-    raise RuntimeError(
-        f"Raising a sparse matrix to exponent {type(scalar)} is not allowed."
-    )
-
-
-def sp_rpower(A: SparseMatrix, scalar: Union[float, int]):
-    """Function for preventing raising a scalar to a sparse matrix exponent
-
-    Parameters
-    ----------
-    A : SparseMatrix
-        Sparse matrix
-    scalar : float or int
-        Scalar
-    """
-    raise RuntimeError(
-        f"Raising {type(scalar)} to a sparse matrix component is not allowed."
-    )
+    # Python falls back to scalar.__rpow__ then TypeError when NotImplemented
+    # is returned.
+    return (
+        val_like(A, A.val**scalar)
+        if isinstance(scalar, (float, int))
+        else NotImplemented
+    )


 SparseMatrix.__add__ = sp_add
 SparseMatrix.__radd__ = sp_add
 SparseMatrix.__mul__ = sp_mul
 SparseMatrix.__rmul__ = sp_mul
 SparseMatrix.__pow__ = sp_power
-SparseMatrix.__rpow__ = sp_rpower
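Because `SparseMatrix.__rmul__` is bound to the same `sp_mul` above, `scalar * A` keeps working after this change, while pairings no handler accepts, such as `A + scalar`, now raise `TypeError` through Python's fallback. A small usage sketch (the `from_coo` constructor and import path are assumptions taken from the docstrings above):

import torch
from dgl.sparse import from_coo  # import path assumed

row = torch.tensor([1, 0, 2])
col = torch.tensor([0, 3, 2])
val = torch.tensor([10, 20, 30])
A = from_coo(row, col, val, shape=(3, 4))

2 * A        # int.__mul__ declines, then A.__rmul__(2) succeeds via sp_mul
A**2         # handled by sp_power
try:
    A + 2    # sp_add declines and int.__radd__ declines
except TypeError as err:
    print(err)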
+import operator
 import sys

 import backend as F
@@ -31,6 +32,11 @@ def test_add_coo(val_shape):
     assert torch.allclose(dense_sum, sum1)
     assert torch.allclose(dense_sum, sum2)
+
+    with pytest.raises(TypeError):
+        A + 2
+    with pytest.raises(TypeError):
+        2 + A


 @pytest.mark.parametrize("val_shape", [(), (2,)])
 def test_add_csr(val_shape):
@@ -52,6 +58,11 @@ def test_add_csr(val_shape):
     assert torch.allclose(dense_sum, sum1)
     assert torch.allclose(dense_sum, sum2)
+
+    with pytest.raises(TypeError):
+        A + 2
+    with pytest.raises(TypeError):
+        2 + A


 @pytest.mark.parametrize("val_shape", [(), (2,)])
 def test_add_csc(val_shape):
@@ -73,6 +84,11 @@ def test_add_csc(val_shape):
     assert torch.allclose(dense_sum, sum1)
     assert torch.allclose(dense_sum, sum2)
+
+    with pytest.raises(TypeError):
+        A + 2
+    with pytest.raises(TypeError):
+        2 + A


 @pytest.mark.parametrize("val_shape", [(), (2,)])
 def test_add_diag(val_shape):
@@ -112,3 +128,20 @@ def test_add_sparse_diag(val_shape):
     assert torch.allclose(dense_sum, sum2)
     assert torch.allclose(dense_sum, sum3)
     assert torch.allclose(dense_sum, sum4)
+
+
+@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
+def test_error_op_sparse_diag(op):
+    ctx = F.ctx()
+    row = torch.tensor([1, 0, 2]).to(ctx)
+    col = torch.tensor([0, 3, 2]).to(ctx)
+    val = torch.randn(row.shape).to(ctx)
+    A = from_coo(row, col, val)
+    shape = (3, 4)
+    D = diag(torch.randn(row.shape[0]).to(ctx), shape=shape)
+
+    with pytest.raises(TypeError):
+        getattr(operator, op)(A, D)
+    with pytest.raises(TypeError):
+        getattr(operator, op)(D, A)
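For readers unfamiliar with the `operator` module used in the test above: `getattr(operator, "mul")(A, D)` is simply `A * D`, so the test exercises the full dispatch chain in both operand orders. A minimal standard-library sketch of the equivalence, with no DGL dependencies:

import operator

assert operator.mul(3, 4) == 3 * 4
assert operator.truediv(3, 4) == 3 / 4
assert operator.pow(3, 4) == 3**4
# operator.mul(A, D) therefore calls A.__mul__(D) first and, if that returns
# NotImplemented, D.__rmul__(A); TypeError is raised only when both decline.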
@@ -60,3 +60,13 @@ def test_diag_op_scalar(v_scalar):
     D2 = power(D1, v_scalar)
     assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
     assert D1.shape == D2.shape
+
+    with pytest.raises(TypeError):
+        D1 + v_scalar
+    with pytest.raises(TypeError):
+        v_scalar + D1
+    with pytest.raises(TypeError):
+        D1 - v_scalar
+    with pytest.raises(TypeError):
+        v_scalar - D1
@@ -62,3 +62,23 @@ def test_pow(val_shape):
     new_row, new_col = A_new.coo()
     assert torch.allclose(new_row, row)
     assert torch.allclose(new_col, col)
+
+
+@pytest.mark.parametrize("op", ["add", "sub"])
+@pytest.mark.parametrize("v_scalar", [2, 2.5])
+def test_error_op_scalar(op, v_scalar):
+    ctx = F.ctx()
+    row = torch.tensor([1, 0, 2]).to(ctx)
+    col = torch.tensor([0, 3, 2]).to(ctx)
+    val = torch.randn(len(row)).to(ctx)
+    A = from_coo(row, col, val, shape=(3, 4))
+
+    with pytest.raises(TypeError):
+        A + v_scalar
+    with pytest.raises(TypeError):
+        v_scalar + A
+    with pytest.raises(TypeError):
+        A - v_scalar
+    with pytest.raises(TypeError):
+        v_scalar - A