Unverified Commit 5f5db2df authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sparse] Make functions compatible with PyTorch scalar tensors (#5163)

* use NotImplemented

* format

* extend to pytorch scalar

* reformat

* reformat

* lint
parent 1d1b08b0
......@@ -4,6 +4,7 @@ from typing import Union
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix
from .utils import Scalar
__all__ = ["add", "sub", "mul", "div", "power"]
......@@ -95,8 +96,8 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]:
def mul(
A: Union[SparseMatrix, DiagMatrix, float, int],
B: Union[SparseMatrix, DiagMatrix, float, int],
A: Union[SparseMatrix, DiagMatrix, Scalar],
B: Union[SparseMatrix, DiagMatrix, Scalar],
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A * B``.
......@@ -115,9 +116,9 @@ def mul(
Parameters
----------
A : SparseMatrix or DiagMatrix or float or int
A : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value
B : SparseMatrix or DiagMatrix or float or int
B : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value
Returns
......@@ -151,7 +152,7 @@ def mul(
def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, float, int]
A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``.
......@@ -172,7 +173,7 @@ def div(
----------
A : DiagMatrix
Diagonal matrix
B : DiagMatrix or float or int
B : DiagMatrix or Scalar
Diagonal matrix or scalar value
Returns
......@@ -197,7 +198,7 @@ def div(
def power(
A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int]
A: Union[SparseMatrix, DiagMatrix], scalar: Scalar
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A ** scalar``.
......@@ -218,7 +219,7 @@ def power(
----------
A : SparseMatrix or DiagMatrix
Sparse matrix or diagonal matrix
scalar : float or int
scalar : Scalar
Exponent
Returns
......
......@@ -3,6 +3,7 @@ from typing import Union
from .diag_matrix import diag, DiagMatrix
from .sparse_matrix import SparseMatrix
from .utils import is_scalar, Scalar
def diag_add(
......@@ -82,14 +83,14 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
return NotImplemented
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise multiplication
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or float or int
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value
Returns
......@@ -113,7 +114,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val * D2.val, D1.shape)
elif isinstance(D2, (float, int)):
elif is_scalar(D2):
return diag(D1.val * D2, D1.shape)
else:
# Python falls back to D2.__rmul__(D1) then TypeError when
......@@ -121,7 +122,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
return NotImplemented
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise division of a diagonal matrix by a diagonal matrix or a
scalar
......@@ -129,7 +130,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or float or int
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix,
division is only applied to the diagonal elements.
......@@ -155,7 +156,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
"must match."
)
return diag(D1.val / D2.val, D1.shape)
elif isinstance(D2, (float, int)):
elif is_scalar(D2):
assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape)
else:
......@@ -165,7 +166,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
# pylint: disable=invalid-name
def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix:
"""Take the power of each nonzero element and return a diagonal matrix with
the result.
......@@ -173,7 +174,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
----------
D : DiagMatrix
Diagonal matrix
scalar : float or int
scalar : Scalar
Exponent
Returns
......@@ -189,9 +190,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
shape=(3, 3))
"""
return (
diag(D.val**scalar, D.shape)
if isinstance(scalar, (float, int))
else NotImplemented
diag(D.val**scalar, D.shape) if is_scalar(scalar) else NotImplemented
)
......
"""DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch
from .sparse_matrix import SparseMatrix, val_like
from .utils import is_scalar, Scalar
def spsp_add(A, B):
......@@ -46,14 +45,14 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented
def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
"""Elementwise multiplication
Parameters
----------
A : SparseMatrix
First operand
B : float or int
B : Scalar
Second operand
Returns
......@@ -81,7 +80,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3)
"""
if isinstance(B, (float, int)):
if is_scalar(B):
return val_like(A, A.val * B)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned.
......@@ -90,7 +89,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
return NotImplemented
def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with
the result.
......@@ -120,11 +119,7 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
"""
# Python falls back to scalar.__rpow__ then TypeError when NotImplemented
# is returned.
return (
val_like(A, A.val**scalar)
if isinstance(scalar, (float, int))
else NotImplemented
)
return val_like(A, A.val**scalar) if is_scalar(scalar) else NotImplemented
SparseMatrix.__add__ = sp_add
......
"""Utilities for DGL sparse module."""
from numbers import Number
from typing import Union
import torch
def is_scalar(x):
"""Check if the input is a scalar."""
return isinstance(x, Number) or (torch.is_tensor(x) and x.dim() == 0)
# Scalar type annotation
Scalar = Union[Number, torch.Tensor]
......@@ -29,7 +29,9 @@ def test_diag_op_diag(op):
assert result.shape == D1.shape
@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar):
ctx = F.ctx()
shape = (3, 4)
......
......@@ -20,7 +20,9 @@ def all_close_sparse(A, row, col, val, shape):
assert A.shape == shape
@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_mul_scalar(v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
......@@ -65,7 +67,9 @@ def test_pow(val_shape):
@pytest.mark.parametrize("op", ["add", "sub"])
@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_error_op_scalar(op, v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment