"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cd00ea1518995e2678fb372d0bf142da8d989faa"
Unverified Commit 5f5db2df authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sparse] Make functions compatible with PyTorch scalar tensors (#5163)

* use NotImplemented

* format

* extend to pytorch scalar

* reformat

* reformat

* lint
parent 1d1b08b0
...@@ -4,6 +4,7 @@ from typing import Union ...@@ -4,6 +4,7 @@ from typing import Union
from .diag_matrix import DiagMatrix from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix from .sparse_matrix import SparseMatrix
from .utils import Scalar
__all__ = ["add", "sub", "mul", "div", "power"] __all__ = ["add", "sub", "mul", "div", "power"]
...@@ -95,8 +96,8 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]: ...@@ -95,8 +96,8 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]:
def mul( def mul(
A: Union[SparseMatrix, DiagMatrix, float, int], A: Union[SparseMatrix, DiagMatrix, Scalar],
B: Union[SparseMatrix, DiagMatrix, float, int], B: Union[SparseMatrix, DiagMatrix, Scalar],
) -> Union[SparseMatrix, DiagMatrix]: ) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``, r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A * B``. equivalent to ``A * B``.
...@@ -115,9 +116,9 @@ def mul( ...@@ -115,9 +116,9 @@ def mul(
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix or float or int A : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value Sparse matrix or diagonal matrix or scalar value
B : SparseMatrix or DiagMatrix or float or int B : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value Sparse matrix or diagonal matrix or scalar value
Returns Returns
...@@ -151,7 +152,7 @@ def mul( ...@@ -151,7 +152,7 @@ def mul(
def div( def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, float, int] A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[DiagMatrix]: ) -> Union[DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``. to ``A / B``.
...@@ -172,7 +173,7 @@ def div( ...@@ -172,7 +173,7 @@ def div(
---------- ----------
A : DiagMatrix A : DiagMatrix
Diagonal matrix Diagonal matrix
B : DiagMatrix or float or int B : DiagMatrix or Scalar
Diagonal matrix or scalar value Diagonal matrix or scalar value
Returns Returns
...@@ -197,7 +198,7 @@ def div( ...@@ -197,7 +198,7 @@ def div(
def power( def power(
A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int] A: Union[SparseMatrix, DiagMatrix], scalar: Scalar
) -> Union[SparseMatrix, DiagMatrix]: ) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``, r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A ** scalar``. equivalent to ``A ** scalar``.
...@@ -218,7 +219,7 @@ def power( ...@@ -218,7 +219,7 @@ def power(
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix or DiagMatrix
Sparse matrix or diagonal matrix Sparse matrix or diagonal matrix
scalar : float or int scalar : Scalar
Exponent Exponent
Returns Returns
......
...@@ -3,6 +3,7 @@ from typing import Union ...@@ -3,6 +3,7 @@ from typing import Union
from .diag_matrix import diag, DiagMatrix from .diag_matrix import diag, DiagMatrix
from .sparse_matrix import SparseMatrix from .sparse_matrix import SparseMatrix
from .utils import is_scalar, Scalar
def diag_add( def diag_add(
...@@ -82,14 +83,14 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: ...@@ -82,14 +83,14 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
return NotImplemented return NotImplemented
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise multiplication """Elementwise multiplication
Parameters Parameters
---------- ----------
D1 : DiagMatrix D1 : DiagMatrix
Diagonal matrix Diagonal matrix
D2 : DiagMatrix or float or int D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value Diagonal matrix or scalar value
Returns Returns
...@@ -113,7 +114,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -113,7 +114,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
f"{D1.shape} and D2 {D2.shape} must match." f"{D1.shape} and D2 {D2.shape} must match."
) )
return diag(D1.val * D2.val, D1.shape) return diag(D1.val * D2.val, D1.shape)
elif isinstance(D2, (float, int)): elif is_scalar(D2):
return diag(D1.val * D2, D1.shape) return diag(D1.val * D2, D1.shape)
else: else:
# Python falls back to D2.__rmul__(D1) then TypeError when # Python falls back to D2.__rmul__(D1) then TypeError when
...@@ -121,7 +122,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -121,7 +122,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
return NotImplemented return NotImplemented
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise division of a diagonal matrix by a diagonal matrix or a """Elementwise division of a diagonal matrix by a diagonal matrix or a
scalar scalar
...@@ -129,7 +130,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -129,7 +130,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
---------- ----------
D1 : DiagMatrix D1 : DiagMatrix
Diagonal matrix Diagonal matrix
D2 : DiagMatrix or float or int D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix, Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix,
division is only applied to the diagonal elements. division is only applied to the diagonal elements.
...@@ -155,7 +156,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -155,7 +156,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
"must match." "must match."
) )
return diag(D1.val / D2.val, D1.shape) return diag(D1.val / D2.val, D1.shape)
elif isinstance(D2, (float, int)): elif is_scalar(D2):
assert D2 != 0, "Division by zero is not allowed." assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape) return diag(D1.val / D2, D1.shape)
else: else:
...@@ -165,7 +166,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ...@@ -165,7 +166,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
# pylint: disable=invalid-name # pylint: disable=invalid-name
def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix:
"""Take the power of each nonzero element and return a diagonal matrix with """Take the power of each nonzero element and return a diagonal matrix with
the result. the result.
...@@ -173,7 +174,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: ...@@ -173,7 +174,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
---------- ----------
D : DiagMatrix D : DiagMatrix
Diagonal matrix Diagonal matrix
scalar : float or int scalar : Scalar
Exponent Exponent
Returns Returns
...@@ -189,9 +190,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: ...@@ -189,9 +190,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
shape=(3, 3)) shape=(3, 3))
""" """
return ( return (
diag(D.val**scalar, D.shape) diag(D.val**scalar, D.shape) if is_scalar(scalar) else NotImplemented
if isinstance(scalar, (float, int))
else NotImplemented
) )
......
"""DGL elementwise operators for sparse matrix module.""" """DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch import torch
from .sparse_matrix import SparseMatrix, val_like from .sparse_matrix import SparseMatrix, val_like
from .utils import is_scalar, Scalar
def spsp_add(A, B): def spsp_add(A, B):
...@@ -46,14 +45,14 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix: ...@@ -46,14 +45,14 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented
def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
"""Elementwise multiplication """Elementwise multiplication
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
First operand First operand
B : float or int B : Scalar
Second operand Second operand
Returns Returns
...@@ -81,7 +80,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: ...@@ -81,7 +80,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
values=tensor([2, 4, 6]), values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
if isinstance(B, (float, int)): if is_scalar(B):
return val_like(A, A.val * B) return val_like(A, A.val * B)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is # Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned. # returned.
...@@ -90,7 +89,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: ...@@ -90,7 +89,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
return NotImplemented return NotImplemented
def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with """Take the power of each nonzero element and return a sparse matrix with
the result. the result.
...@@ -120,11 +119,7 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: ...@@ -120,11 +119,7 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
""" """
# Python falls back to scalar.__rpow__ then TypeError when NotImplemented # Python falls back to scalar.__rpow__ then TypeError when NotImplemented
# is returned. # is returned.
return ( return val_like(A, A.val**scalar) if is_scalar(scalar) else NotImplemented
val_like(A, A.val**scalar)
if isinstance(scalar, (float, int))
else NotImplemented
)
SparseMatrix.__add__ = sp_add SparseMatrix.__add__ = sp_add
......
"""Utilities for DGL sparse module."""
from numbers import Number
from typing import Union
import torch
def is_scalar(x):
"""Check if the input is a scalar."""
return isinstance(x, Number) or (torch.is_tensor(x) and x.dim() == 0)
# Scalar type annotation
Scalar = Union[Number, torch.Tensor]
...@@ -29,7 +29,9 @@ def test_diag_op_diag(op): ...@@ -29,7 +29,9 @@ def test_diag_op_diag(op):
assert result.shape == D1.shape assert result.shape == D1.shape
@pytest.mark.parametrize("v_scalar", [2, 2.5]) @pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar): def test_diag_op_scalar(v_scalar):
ctx = F.ctx() ctx = F.ctx()
shape = (3, 4) shape = (3, 4)
......
...@@ -20,7 +20,9 @@ def all_close_sparse(A, row, col, val, shape): ...@@ -20,7 +20,9 @@ def all_close_sparse(A, row, col, val, shape):
assert A.shape == shape assert A.shape == shape
@pytest.mark.parametrize("v_scalar", [2, 2.5]) @pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_mul_scalar(v_scalar): def test_mul_scalar(v_scalar):
ctx = F.ctx() ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx) row = torch.tensor([1, 0, 2]).to(ctx)
...@@ -65,7 +67,9 @@ def test_pow(val_shape): ...@@ -65,7 +67,9 @@ def test_pow(val_shape):
@pytest.mark.parametrize("op", ["add", "sub"]) @pytest.mark.parametrize("op", ["add", "sub"])
@pytest.mark.parametrize("v_scalar", [2, 2.5]) @pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_error_op_scalar(op, v_scalar): def test_error_op_scalar(op, v_scalar):
ctx = F.ctx() ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx) row = torch.tensor([1, 0, 2]).to(ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment