Unverified Commit 92e383c3 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Remove the Python DiagMatrix class (#5444)

* [Sparse] Remove the Python DiagMatrix class

* Update

* Update

* update

* Update
parent f5ddb114
...@@ -32,7 +32,8 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -32,7 +32,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("csc", &SparseMatrix::CSCTensors) .def("csc", &SparseMatrix::CSCTensors)
.def("transpose", &SparseMatrix::Transpose) .def("transpose", &SparseMatrix::Transpose)
.def("coalesce", &SparseMatrix::Coalesce) .def("coalesce", &SparseMatrix::Coalesce)
.def("has_duplicate", &SparseMatrix::HasDuplicate); .def("has_duplicate", &SparseMatrix::HasDuplicate)
.def("is_diag", &SparseMatrix::HasDiag);
m.def("from_coo", &SparseMatrix::FromCOO) m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR) .def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC) .def("from_csc", &SparseMatrix::FromCSC)
......
...@@ -130,14 +130,18 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike( ...@@ -130,14 +130,18 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
TORCH_CHECK( TORCH_CHECK(
mat->value().device() == value.device(), "The device of the ", mat->value().device() == value.device(), "The device of the ",
"old values and the new values must be the same."); "old values and the new values must be the same.");
auto shape = mat->shape(); const auto& shape = mat->shape();
if (mat->HasDiag()) {
return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, shape);
}
if (mat->HasCOO()) { if (mat->HasCOO()) {
return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape); return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape);
} else if (mat->HasCSR()) { }
if (mat->HasCSR()) {
return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape); return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape);
} else {
return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);
} }
TORCH_CHECK(mat->HasCSC(), "Invalid sparse format for ValLike.")
return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);
} }
std::shared_ptr<COO> SparseMatrix::COOPtr() { std::shared_ptr<COO> SparseMatrix::COOPtr() {
...@@ -195,7 +199,9 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const { ...@@ -195,7 +199,9 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
auto shape = shape_; auto shape = shape_;
std::swap(shape[0], shape[1]); std::swap(shape[0], shape[1]);
auto value = value_; auto value = value_;
if (HasCOO()) { if (HasDiag()) {
return SparseMatrix::FromDiag(value, shape);
} else if (HasCOO()) {
auto coo = COOTranspose(coo_); auto coo = COOTranspose(coo_);
return SparseMatrix::FromCOOPointer(coo, value, shape); return SparseMatrix::FromCOOPointer(coo, value, shape);
} else if (HasCSR()) { } else if (HasCSR()) {
......
...@@ -5,17 +5,14 @@ import sys ...@@ -5,17 +5,14 @@ import sys
import torch import torch
from .._ffi import libinfo from .._ffi import libinfo
from .diag_matrix import *
from .elementwise_op import * from .elementwise_op import *
from .elementwise_op_diag import *
from .elementwise_op_sp import * from .elementwise_op_sp import *
from .matmul import * from .matmul import *
from .reduction import * # pylint: disable=W0622 from .reduction import * # pylint: disable=W0622
from .sddmm import * from .sddmm import *
from .softmax import * from .softmax import *
from .sparse_matrix import * from .sparse_matrix import *
from .unary_op_diag import * from .unary_op import *
from .unary_op_sp import *
def load_dgl_sparse(): def load_dgl_sparse():
......
"""DGL diagonal matrix module."""
# pylint: disable= invalid-name
from typing import Optional, Tuple
import torch
from .sparse_matrix import from_coo, SparseMatrix
class DiagMatrix:
r"""Class for diagonal matrix.
Parameters
----------
val : torch.Tensor
Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``
shape : tuple[int, int], optional
If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,
otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``
"""
def __init__(
self, val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
):
len_val = len(val)
if shape is not None:
assert len_val == min(shape), (
f"Expect len(val) to be min(shape), got {len_val} for len(val)"
"and {shape} for shape."
)
else:
shape = (len_val, len_val)
self._val = val
self._shape = shape
def __repr__(self):
return _diag_matrix_str(self)
@property
def val(self) -> torch.Tensor:
"""Returns the values of the non-zero elements.
Returns
-------
torch.Tensor
Values of the non-zero elements
"""
return self._val
@property
def shape(self) -> Tuple[int]:
"""Returns the shape of the diagonal matrix.
Returns
-------
Tuple[int]
The shape of the diagonal matrix
"""
return self._shape
@property
def nnz(self) -> int:
"""Returns the number of non-zero elements in the diagonal matrix.
Returns
-------
int
The number of non-zero elements in the diagonal matrix
"""
return self.val.shape[0]
@property
def dtype(self) -> torch.dtype:
"""Returns the data type of the diagonal matrix.
Returns
-------
torch.dtype
Data type of the diagonal matrix
"""
return self.val.dtype
@property
def device(self) -> torch.device:
"""Returns the device the diagonal matrix is on.
Returns
-------
torch.device
The device the diagonal matrix is on
"""
return self.val.device
def to_sparse(self) -> SparseMatrix:
"""Returns a copy in sparse matrix format of the diagonal matrix.
Returns
-------
SparseMatrix
The copy in sparse matrix format
Examples
--------
>>> import torch
>>> val = torch.ones(5)
>>> D = dglsp.diag(val)
>>> D.to_sparse()
SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(5, 5), nnz=5)
"""
row = col = torch.arange(len(self.val)).to(self.device)
return from_coo(row=row, col=col, val=self.val, shape=self.shape)
def to_dense(self) -> torch.Tensor:
"""Returns a copy in dense matrix format of the diagonal matrix.
Returns
-------
torch.Tensor
The copy in dense matrix format
"""
val = self.val
device = self.device
shape = self.shape + val.shape[1:]
mat = torch.zeros(shape, device=device, dtype=self.dtype)
row = col = torch.arange(len(val)).to(device)
mat[row, col] = val
return mat
def t(self):
"""Alias of :meth:`transpose()`"""
return self.transpose()
@property
def T(self): # pylint: disable=C0103
"""Alias of :meth:`transpose()`"""
return self.transpose()
def transpose(self):
"""Returns a matrix that is a transposed version of the diagonal matrix.
Returns
-------
DiagMatrix
The transpose of the matrix
Examples
--------
>>> val = torch.arange(1, 5).float()
>>> D = dglsp.diag(val, shape=(4, 5))
>>> D.transpose()
DiagMatrix(val=tensor([1., 2., 3., 4.]),
shape=(5, 4))
"""
return DiagMatrix(self.val, self.shape[::-1])
def to(self, device=None, dtype=None):
"""Performs matrix dtype and/or device conversion. If the target device
and dtype are already in use, the original matrix will be returned.
Parameters
----------
device : torch.device, optional
The target device of the matrix if provided, otherwise the current
device will be used
dtype : torch.dtype, optional
The target data type of the matrix values if provided, otherwise the
current data type will be used
Returns
-------
DiagMatrix
The converted matrix
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.to(device="cuda:0", dtype=torch.int32)
DiagMatrix(values=tensor([1, 1], device='cuda:0', dtype=torch.int32),
shape=(2, 2))
"""
if device is None:
device = self.device
if dtype is None:
dtype = self.dtype
if device == self.device and dtype == self.dtype:
return self
return diag(self.val.to(device=device, dtype=dtype), self.shape)
def cuda(self):
"""Moves the matrix to GPU. If the matrix is already on GPU, the
original matrix will be returned. If multiple GPU devices exist,
``cuda:0`` will be selected.
Returns
-------
DiagMatrix
The matrix on GPU
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.cuda()
DiagMatrix(values=tensor([1., 1.], device='cuda:0'),
shape=(2, 2))
"""
return self.to(device="cuda")
def cpu(self):
"""Moves the matrix to CPU. If the matrix is already on CPU, the
original matrix will be returned.
Returns
-------
DiagMatrix
The matrix on CPU
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.cpu()
DiagMatrix(values=tensor([1., 1.]),
shape=(2, 2))
"""
return self.to(device="cpu")
def float(self):
"""Converts the matrix values to float32 data type. If the matrix
already uses float data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with float values
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.float()
DiagMatrix(values=tensor([1., 1.]),
shape=(2, 2))
"""
return self.to(dtype=torch.float)
def double(self):
"""Converts the matrix values to double data type. If the matrix already
uses double data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with double values
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.double()
DiagMatrix(values=tensor([1., 1.], dtype=torch.float64),
shape=(2, 2))
"""
return self.to(dtype=torch.double)
def int(self):
"""Converts the matrix values to int32 data type. If the matrix already
uses int data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with int values
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.int()
DiagMatrix(values=tensor([1, 1], dtype=torch.int32),
shape=(2, 2))
"""
return self.to(dtype=torch.int)
def long(self):
"""Converts the matrix values to long data type. If the matrix already
uses long data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with long values
Examples
--------
>>> val = torch.ones(2)
>>> D = dglsp.diag(val)
>>> D.long()
DiagMatrix(values=tensor([1, 1]),
shape=(2, 2))
"""
return self.to(dtype=torch.long)
def diag(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
) -> DiagMatrix:
"""Creates a diagonal matrix based on the diagonal values.
Parameters
----------
val : torch.Tensor
Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``
shape : tuple[int, int], optional
If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,
otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
Case1: 5-by-5 diagonal matrix with scaler values on the diagonal
>>> import torch
>>> val = torch.ones(5)
>>> dglsp.diag(val)
DiagMatrix(val=tensor([1., 1., 1., 1., 1.]),
shape=(5, 5))
Case2: 5-by-10 diagonal matrix with scaler values on the diagonal
>>> val = torch.ones(5)
>>> dglsp.diag(val, shape=(5, 10))
DiagMatrix(val=tensor([1., 1., 1., 1., 1.]),
shape=(5, 10))
Case3: 5-by-5 diagonal matrix with vector values on the diagonal
>>> val = torch.randn(5, 3)
>>> D = dglsp.diag(val)
>>> D.shape
(5, 5)
>>> D.nnz
5
"""
assert (
val.dim() <= 2
), "The values of a DiagMatrix can only be scalars or vectors."
# NOTE(Mufei): this may not be needed if DiagMatrix is simple enough
return DiagMatrix(val, shape)
def identity(
shape: Tuple[int, int],
d: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> DiagMatrix:
r"""Creates a diagonal matrix with ones on the diagonal and zeros elsewhere.
Parameters
----------
shape : tuple[int, int]
Shape of the matrix.
d : int, optional
If None, the diagonal entries will be scaler 1. Otherwise, the diagonal
entries will be a 1-valued tensor of shape ``(d)``.
dtype : torch.dtype, optional
The data type of the matrix
device : torch.device, optional
The device of the matrix
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
Case1: 3-by-3 matrix with scaler diagonal values
.. code::
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
>>> dglsp.identity(shape=(3, 3))
DiagMatrix(val=tensor([1., 1., 1.]),
shape=(3, 3))
Case2: 3-by-5 matrix with scaler diagonal values
.. code::
[[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]]
>>> dglsp.identity(shape=(3, 5))
DiagMatrix(val=tensor([1., 1., 1.]),
shape=(3, 5))
Case3: 3-by-3 matrix with vector diagonal values
>>> dglsp.identity(shape=(3, 3), d=2)
DiagMatrix(values=tensor([[1., 1.],
[1., 1.],
[1., 1.]]),
shape=(3, 3), val_size=(2,))
"""
len_val = min(shape)
if d is None:
val_shape = (len_val,)
else:
val_shape = (len_val, d)
val = torch.ones(val_shape, dtype=dtype, device=device)
return diag(val, shape)
def _diag_matrix_str(spmat: DiagMatrix) -> str:
"""Internal function for converting a diagonal matrix to string
representation.
"""
values_str = str(spmat.val)
meta_str = f"shape={spmat.shape}"
if spmat.val.dim() > 1:
val_size = tuple(spmat.val.shape[1:])
meta_str += f", val_size={val_size}"
prefix = f"{type(spmat).__name__}("
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
final_str = (
"values="
+ _add_indent(values_str, len("values="))
+ ",\n"
+ meta_str
+ ")"
)
final_str = prefix + _add_indent(final_str, len(prefix))
return final_str
...@@ -2,43 +2,26 @@ ...@@ -2,43 +2,26 @@
"""DGL elementwise operator module.""" """DGL elementwise operator module."""
from typing import Union from typing import Union
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix from .sparse_matrix import SparseMatrix
from .utils import Scalar from .utils import Scalar
__all__ = ["add", "sub", "mul", "div", "power"] __all__ = ["add", "sub", "mul", "div", "power"]
def add( def add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
A: Union[DiagMatrix, SparseMatrix], B: Union[DiagMatrix, SparseMatrix] r"""Elementwise addition for ``SparseMatrix``, equivalent to ``A + B``.
) -> Union[DiagMatrix, SparseMatrix]:
r"""Elementwise addition for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A + B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
| SparseMatrix | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : DiagMatrix or SparseMatrix A : SparseMatrix
Diagonal matrix or sparse matrix Sparse matrix
B : DiagMatrix or SparseMatrix B : SparseMatrix
Diagonal matrix or sparse matrix Sparse matrix
Returns Returns
------- -------
DiagMatrix or SparseMatrix SparseMatrix
Diagonal matrix if both :attr:`A` and :attr:`B` are diagonal matrices, Sparse matrix
sparse matrix otherwise
Examples Examples
-------- --------
...@@ -55,36 +38,20 @@ def add( ...@@ -55,36 +38,20 @@ def add(
return A + B return A + B
def sub( def sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
A: Union[DiagMatrix, SparseMatrix], B: Union[DiagMatrix, SparseMatrix] r"""Elementwise subtraction for ``SparseMatrix``, equivalent to ``A - B``.
) -> Union[DiagMatrix, SparseMatrix]:
r"""Elementwise subtraction for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A - B``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
| SparseMatrix | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : DiagMatrix or SparseMatrix A : SparseMatrix
Diagonal matrix or sparse matrix Sparse matrix
B : DiagMatrix or SparseMatrix B : SparseMatrix
Diagonal matrix or sparse matrix Sparse matrix
Returns Returns
------- -------
DiagMatrix or SparseMatrix SparseMatrix
Diagonal matrix if both :attr:`A` and :attr:`B` are diagonal matrices, Sparse matrix
sparse matrix otherwise
Examples Examples
-------- --------
...@@ -102,35 +69,25 @@ def sub( ...@@ -102,35 +69,25 @@ def sub(
def mul( def mul(
A: Union[SparseMatrix, DiagMatrix, Scalar], A: Union[SparseMatrix, Scalar], B: Union[SparseMatrix, Scalar]
B: Union[SparseMatrix, DiagMatrix, Scalar], ) -> SparseMatrix:
) -> Union[SparseMatrix, DiagMatrix]: r"""Elementwise multiplication for ``SparseMatrix``, equivalent to
r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``, ``A * B``.
equivalent to ``A * B``.
If both :attr:`A` and :attr:`B` are sparse matrices, both of them should be
The supported combinations are shown as follows. diagonal matrices.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| scalar | ✅ | ✅ | 🚫 |
+--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix or Scalar A : SparseMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value Sparse matrix or scalar value
B : SparseMatrix or DiagMatrix or Scalar B : SparseMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value Sparse matrix or scalar value
Returns Returns
------- -------
SparseMatrix or DiagMatrix SparseMatrix
Either sparse matrix or diagonal matrix Sparse matrix
Examples Examples
-------- --------
...@@ -145,59 +102,55 @@ def mul( ...@@ -145,59 +102,55 @@ def mul(
>>> D = dglsp.diag(torch.arange(1, 4)) >>> D = dglsp.diag(torch.arange(1, 4))
>>> dglsp.mul(D, 2) >>> dglsp.mul(D, 2)
DiagMatrix(val=tensor([2, 4, 6]), SparseMatrix(indices=tensor([[0, 1, 2],
shape=(3, 3)) [0, 1, 2]]),
values=tensor([2, 4, 6]),
shape=(3, 3), nnz=3)
>>> D = dglsp.diag(torch.arange(1, 4)) >>> D = dglsp.diag(torch.arange(1, 4))
>>> dglsp.mul(D, D) >>> dglsp.mul(D, D)
DiagMatrix(val=tensor([1, 4, 9]), SparseMatrix(indices=tensor([[0, 1, 2],
shape=(3, 3)) [0, 1, 2]]),
values=tensor([1, 4, 9]),
shape=(3, 3), nnz=3)
""" """
return A * B return A * B
def div( def div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
A: Union[SparseMatrix, DiagMatrix], B: Union[DiagMatrix, Scalar] r"""Elementwise division for ``SparseMatrix``, equivalent to ``A / B``.
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``.
The supported combinations are shown as follows. If both :attr:`A` and :attr:`B` are sparse matrices, both of them should be
diagonal matrices.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse or diagonal matrix Sparse matrix
B : DiagMatrix or Scalar B : SparseMatrix or Scalar
Diagonal matrix or scalar value Sparse matrix or scalar value
Returns Returns
------- -------
DiagMatrix SparseMatrix
Diagonal matrix Sparse matrix
Examples Examples
-------- --------
>>> A = dglsp.diag(torch.arange(1, 4)) >>> A = dglsp.diag(torch.arange(1, 4))
>>> B = dglsp.diag(torch.arange(10, 13)) >>> B = dglsp.diag(torch.arange(10, 13))
>>> dglsp.div(A, B) >>> dglsp.div(A, B)
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]), SparseMatrix(indices=tensor([[0, 1, 2],
shape=(3, 3)) [0, 1, 2]]),
values=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3), nnz=3)
>>> A = dglsp.diag(torch.arange(1, 4)) >>> A = dglsp.diag(torch.arange(1, 4))
>>> dglsp.div(A, 2) >>> dglsp.div(A, 2)
DiagMatrix(val=tensor([0.5000, 1.0000, 1.5000]), SparseMatrix(indices=tensor([[0, 1, 2],
shape=(3, 3)) [0, 1, 2]]),
values=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 3), nnz=3)
>>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]]) >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
>>> val = torch.tensor([1, 2, 3]) >>> val = torch.tensor([1, 2, 3])
...@@ -211,35 +164,21 @@ def div( ...@@ -211,35 +164,21 @@ def div(
return A / B return A / B
def power( def power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
A: Union[SparseMatrix, DiagMatrix], scalar: Scalar r"""Elementwise exponentiation ``SparseMatrix``, equivalent to
) -> Union[SparseMatrix, DiagMatrix]: ``A ** scalar``.
r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A ** scalar``.
The supported combinations are shown as follows.
+--------------+------------+--------------+--------+
| A \\ B | DiagMatrix | SparseMatrix | scalar |
+--------------+------------+--------------+--------+
| DiagMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse matrix or diagonal matrix Sparse matrix
scalar : Scalar scalar : Scalar
Exponent Exponent
Returns Returns
------- -------
SparseMatrix or DiagMatrix SparseMatrix
Sparse matrix or diagonal matrix, same type as A Sparse matrix
Examples Examples
-------- --------
...@@ -254,7 +193,9 @@ def power( ...@@ -254,7 +193,9 @@ def power(
>>> D = dglsp.diag(torch.arange(1, 4)) >>> D = dglsp.diag(torch.arange(1, 4))
>>> dglsp.power(D, 2) >>> dglsp.power(D, 2)
DiagMatrix(val=tensor([1, 4, 9]), SparseMatrix(indices=tensor([[0, 1, 2],
shape=(3, 3)) [0, 1, 2]]),
values=tensor([1, 4, 9]),
shape=(3, 3), nnz=3)
""" """
return A**scalar return A**scalar
"""DGL elementwise operators for diagonal matrix module."""
from typing import Union
from .diag_matrix import diag, DiagMatrix
from .sparse_matrix import SparseMatrix
from .utils import is_scalar, Scalar
def diag_add(
D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
"""Elementwise addition
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix
Returns
-------
DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix, same as D2
Examples
--------
>>> D1 = dglsp.diag(torch.arange(1, 4))
>>> D2 = dglsp.diag(torch.arange(10, 13))
>>> D1 + D2
DiagMatrix(val=tensor([11, 13, 15]),
shape=(3, 3))
"""
if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val + D2.val, D1.shape)
elif isinstance(D2, SparseMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
)
D1 = D1.to_sparse()
return D1 + D2
# Python falls back to D2.__radd__(D1) then TypeError when NotImplemented
# is returned.
return NotImplemented
def diag_sub(
D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
"""Elementwise subtraction
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix
Returns
-------
DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix, same as D2
Examples
--------
>>> D1 = dglsp.diag(torch.arange(1, 4))
>>> D2 = dglsp.diag(torch.arange(10, 13))
>>> D1 - D2
DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3))
"""
if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val - D2.val, D1.shape)
elif isinstance(D2, SparseMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
)
D1 = D1.to_sparse()
return D1 - D2
# Python falls back to D2.__rsub__(D1) then TypeError when NotImplemented
# is returned.
return NotImplemented
def diag_rsub(
D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
"""Elementwise subtraction in the opposite direction (``D2 - D1``)
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix
Returns
-------
DiagMatrix or SparseMatrix
Diagonal matrix or sparse matrix, same as D2
Examples
--------
>>> D1 = dglsp.diag(torch.arange(1, 4))
>>> D2 = dglsp.diag(torch.arange(10, 13))
>>> D2 - D1
DiagMatrix(val=tensor([-9, -9, -9]),
shape=(3, 3))
"""
return -(D1 - D2)
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise multiplication
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D = dglsp.diag(torch.arange(1, 4))
>>> D * 2.5
DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
shape=(3, 3))
>>> 2 * D
DiagMatrix(val=tensor([2, 4, 6]),
shape=(3, 3))
"""
if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val * D2.val, D1.shape)
elif is_scalar(D2):
return diag(D1.val * D2, D1.shape)
else:
# Python falls back to D2.__rmul__(D1) then TypeError when
# NotImplemented is returned.
return NotImplemented
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise division of a diagonal matrix by a diagonal matrix or a
scalar
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix,
division is only applied to the diagonal elements.
Returns
-------
DiagMatrix
diagonal matrix
Examples
--------
>>> D1 = dglsp.diag(torch.arange(1, 4))
>>> D2 = dglsp.diag(torch.arange(10, 13))
>>> D1 / D2
DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
shape=(3, 3))
>>> D1 / 2.5
DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
shape=(3, 3))
"""
if isinstance(D2, DiagMatrix):
assert D1.shape == D2.shape, (
f"The shape of diagonal matrix D1 {D1.shape} and D2 {D2.shape} "
"must match."
)
return diag(D1.val / D2.val, D1.shape)
elif is_scalar(D2):
assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape)
else:
# Python falls back to D2.__rtruediv__(D1) then TypeError when
# NotImplemented is returned.
return NotImplemented
# pylint: disable=invalid-name
def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix:
"""Take the power of each nonzero element and return a diagonal matrix with
the result.
Parameters
----------
D : DiagMatrix
Diagonal matrix
scalar : Scalar
Exponent
Returns
-------
DiagMatrix
Diagonal matrix
Examples
--------
>>> D = dglsp.diag(torch.arange(1, 4))
>>> D ** 2
DiagMatrix(val=tensor([1, 4, 9]),
shape=(3, 3))
"""
return (
diag(D.val**scalar, D.shape) if is_scalar(scalar) else NotImplemented
)
DiagMatrix.__add__ = diag_add
DiagMatrix.__radd__ = diag_add
DiagMatrix.__sub__ = diag_sub
DiagMatrix.__rsub__ = diag_rsub
DiagMatrix.__mul__ = diag_mul
DiagMatrix.__rmul__ = diag_mul
DiagMatrix.__truediv__ = diag_div
DiagMatrix.__pow__ = diag_power
"""DGL elementwise operators for sparse matrix module.""" """DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch import torch
from .sparse_matrix import SparseMatrix, val_like from .sparse_matrix import diag, SparseMatrix, val_like
from .utils import is_scalar, Scalar from .utils import is_scalar, Scalar
...@@ -78,15 +80,17 @@ def sp_sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix: ...@@ -78,15 +80,17 @@ def sp_sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
return spsp_add(A, -B) if isinstance(B, SparseMatrix) else NotImplemented return spsp_add(A, -B) if isinstance(B, SparseMatrix) else NotImplemented
def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix: def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""Elementwise multiplication """Elementwise multiplication
If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must be
diagonal matrices.
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
First operand First operand
B : Scalar B : SparseMatrix or Scalar
Second operand Second operand
Returns Returns
...@@ -115,6 +119,12 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix: ...@@ -115,6 +119,12 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
""" """
if is_scalar(B): if is_scalar(B):
return val_like(A, A.val * B) return val_like(A, A.val * B)
if A.is_diag() and B.is_diag():
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise multiplication."
)
return diag(A.val * B.val, A.shape)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is # Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned. # returned.
# So this also handles the case of scalar * SparseMatrix since we set # So this also handles the case of scalar * SparseMatrix since we set
...@@ -122,14 +132,17 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix: ...@@ -122,14 +132,17 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
return NotImplemented return NotImplemented
def sp_div(A: SparseMatrix, B: Scalar) -> SparseMatrix: def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""Elementwise division """Elementwise division
If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must be
diagonal matrices.
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
First operand First operand
B : Scalar B : SparseMatrix or Scalar
Second operand Second operand
Returns Returns
...@@ -150,6 +163,12 @@ def sp_div(A: SparseMatrix, B: Scalar) -> SparseMatrix: ...@@ -150,6 +163,12 @@ def sp_div(A: SparseMatrix, B: Scalar) -> SparseMatrix:
""" """
if is_scalar(B): if is_scalar(B):
return val_like(A, A.val / B) return val_like(A, A.val / B)
if A.is_diag() and B.is_diag():
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise division."
)
return diag(A.val / B.val, A.shape)
# Python falls back to B.__rtruediv__(A) then TypeError when NotImplemented # Python falls back to B.__rtruediv__(A) then TypeError when NotImplemented
# is returned. # is returned.
return NotImplemented return NotImplemented
......
...@@ -4,19 +4,17 @@ from typing import Union ...@@ -4,19 +4,17 @@ from typing import Union
import torch import torch
from .diag_matrix import diag, DiagMatrix from .sparse_matrix import SparseMatrix
from .sparse_matrix import SparseMatrix, val_like
__all__ = ["spmm", "bspmm", "spspmm", "matmul"] __all__ = ["spmm", "bspmm", "spspmm", "matmul"]
def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: def spmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
"""Multiplies a sparse matrix by a dense matrix, equivalent to ``A @ X``. """Multiplies a sparse matrix by a dense matrix, equivalent to ``A @ X``.
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse matrix of shape ``(L, M)`` with scalar values Sparse matrix of shape ``(L, M)`` with scalar values
X : torch.Tensor X : torch.Tensor
Dense matrix of shape ``(M, N)`` or ``(M)`` Dense matrix of shape ``(M, N)`` or ``(M)``
...@@ -30,7 +28,7 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -30,7 +28,7 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
-------- --------
>>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]]) >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val = torch.randn(len(row)) >>> val = torch.randn(indices.shape[1])
>>> A = dglsp.spmatrix(indices, val) >>> A = dglsp.spmatrix(indices, val)
>>> X = torch.randn(2, 3) >>> X = torch.randn(2, 3)
>>> result = dglsp.spmm(A, X) >>> result = dglsp.spmm(A, X)
...@@ -40,25 +38,22 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -40,25 +38,22 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
torch.Size([2, 3]) torch.Size([2, 3])
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, SparseMatrix
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}." ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
X, torch.Tensor X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}." ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
# The input is a DiagMatrix. Cast it to SparseMatrix
if not isinstance(A, SparseMatrix):
A = A.to_sparse()
return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X) return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)
def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: def bspmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
"""Multiplies a sparse matrix by a dense matrix by batches, equivalent to """Multiplies a sparse matrix by a dense matrix by batches, equivalent to
``A @ X``. ``A @ X``.
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse matrix of shape ``(L, M)`` with vector values of length ``K`` Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
X : torch.Tensor X : torch.Tensor
Dense matrix of shape ``(M, N, K)`` Dense matrix of shape ``(M, N, K)``
...@@ -82,115 +77,30 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor: ...@@ -82,115 +77,30 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
torch.Size([3, 3, 2]) torch.Size([3, 3, 2])
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, SparseMatrix
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}." ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
X, torch.Tensor X, torch.Tensor
), f"Expect arg2 to be a torch.Tensor, got {type(X)}." ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
return spmm(A, X) return spmm(A, X)
def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix: def spspmm(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix.
Parameters
----------
A : DiagMatrix
Diagonal matrix of shape ``(L, M)``
B : DiagMatrix
Diagonal matrix of shape ``(M, N)``
Returns
-------
DiagMatrix
Diagonal matrix of shape ``(L, N)``
"""
M, N = A.shape
N, P = B.shape
common_diag_len = min(M, N, P)
new_diag_len = min(M, P)
diag_val = torch.zeros(new_diag_len)
diag_val[:common_diag_len] = (
A.val[:common_diag_len] * B.val[:common_diag_len]
)
return diag(diag_val.to(A.device), (M, P))
def _sparse_diag_mm(A, D):
"""Internal function for multiplying a sparse matrix by a diagonal matrix.
Parameters
----------
A : SparseMatrix
Sparse matrix of shape ``(L, M)``
D : DiagMatrix
Diagonal matrix of shape ``(M, N)``
Returns
-------
SparseMatrix
Sparse matrix of shape ``(L, N)``
"""
assert (
A.shape[1] == D.shape[0]
), f"The second dimension of SparseMatrix should be equal to the first \
dimension of DiagMatrix in matmul(SparseMatrix, DiagMatrix), but the \
shapes of SparseMatrix and DiagMatrix are {A.shape} and {D.shape} \
respectively."
assert (
D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix) \
but got {D.shape}."
return val_like(A, D.val[A.col] * A.val)
def _diag_sparse_mm(D, A):
"""Internal function for multiplying a diagonal matrix by a sparse matrix.
Parameters
----------
D : DiagMatrix
Diagonal matrix of shape ``(L, M)``
A : SparseMatrix
Sparse matrix of shape ``(M, N)``
Returns
-------
SparseMatrix
Sparse matrix of shape ``(L, N)``
"""
assert (
D.shape[1] == A.shape[0]
), f"The second dimension of DiagMatrix should be equal to the first \
dimension of SparseMatrix in matmul(DiagMatrix, SparseMatrix), but the \
shapes of DiagMatrix and SparseMatrix are {D.shape} and {A.shape} \
respectively."
assert (
D.shape[0] == D.shape[1]
), f"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix) \
but got {D.shape}."
return val_like(A, D.val[A.row] * A.val)
def spspmm(
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
"""Multiplies a sparse matrix by a sparse matrix, equivalent to ``A @ B``. """Multiplies a sparse matrix by a sparse matrix, equivalent to ``A @ B``.
The non-zero values of the two sparse matrices must be 1D. The non-zero values of the two sparse matrices must be 1D.
Parameters Parameters
---------- ----------
A : SparseMatrix or DiagMatrix A : SparseMatrix
Sparse matrix of shape ``(L, M)`` Sparse matrix of shape ``(L, M)``
B : SparseMatrix or DiagMatrix B : SparseMatrix
Sparse matrix of shape ``(M, N)`` Sparse matrix of shape ``(M, N)``
Returns Returns
------- -------
SparseMatrix or DiagMatrix SparseMatrix
Matrix of shape ``(L, N)``. It is a DiagMatrix object if both matrices Sparse matrix of shape ``(L, N)``.
are DiagMatrix objects, otherwise a SparseMatrix object.
Examples Examples
-------- --------
...@@ -208,68 +118,52 @@ def spspmm( ...@@ -208,68 +118,52 @@ def spspmm(
shape=(2, 3), nnz=5) shape=(2, 3), nnz=5)
""" """
assert isinstance( assert isinstance(
A, (SparseMatrix, DiagMatrix) A, SparseMatrix
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}." ), f"Expect A1 to be a SparseMatrix object, got {type(A)}."
assert isinstance( assert isinstance(
B, (SparseMatrix, DiagMatrix) B, SparseMatrix
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}." ), f"Expect A2 to be a SparseMatrix object, got {type(B)}."
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B)
if isinstance(A, DiagMatrix):
return _diag_sparse_mm(A, B)
if isinstance(B, DiagMatrix):
return _sparse_diag_mm(A, B)
return SparseMatrix( return SparseMatrix(
torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix) torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
) )
def matmul( def matmul(
A: Union[torch.Tensor, SparseMatrix, DiagMatrix], A: Union[torch.Tensor, SparseMatrix], B: Union[torch.Tensor, SparseMatrix]
B: Union[torch.Tensor, SparseMatrix, DiagMatrix], ) -> Union[torch.Tensor, SparseMatrix]:
) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]: """Multiplies two dense/sparse matrices, equivalent to ``A @ B``.
"""Multiplies two dense/sparse/diagonal matrices, equivalent to ``A @ B``.
This function does not support the case where :attr:`A` is a \
The supported combinations are shown as follows. ``torch.Tensor`` and :attr:`B` is a ``SparseMatrix``.
+--------------+--------+------------+--------------+
| A \\ B | Tensor | DiagMatrix | SparseMatrix |
+--------------+--------+------------+--------------+
| Tensor | ✅ | 🚫 | 🚫 |
+--------------+--------+------------+--------------+
| SparseMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
| DiagMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
* If both matrices are torch.Tensor, it calls \ * If both matrices are torch.Tensor, it calls \
:func:`torch.matmul()`. The result is a dense matrix. :func:`torch.matmul()`. The result is a dense matrix.
* If both matrices are sparse or diagonal, it calls \ * If both matrices are sparse, it calls :func:`dgl.sparse.spspmm`. The \
:func:`dgl.sparse.spspmm`. The result is a sparse matrix. result is a sparse matrix.
* If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it \ * If :attr:`A` is sparse while :attr:`B` is dense, it calls \
calls :func:`dgl.sparse.spmm`. The result is a dense matrix. :func:`dgl.sparse.spmm`. The result is a dense matrix.
* The operator supports batched sparse-dense matrix multiplication. In \ * The operator supports batched sparse-dense matrix multiplication. In \
this case, the sparse or diagonal matrix :attr:`A` should have shape \ this case, the sparse matrix :attr:`A` should have shape ``(L, M)``, \
``(L, M)``, where the non-zero values have a batch dimension ``K``. \ where the non-zero values have a batch dimension ``K``. The dense \
The dense matrix :attr:`B` should have shape ``(M, N, K)``. The output \ matrix :attr:`B` should have shape ``(M, N, K)``. The output \
is a dense matrix of shape ``(L, N, K)``. is a dense matrix of shape ``(L, N, K)``.
* Sparse-sparse matrix multiplication does not support batched computation. * Sparse-sparse matrix multiplication does not support batched computation.
Parameters Parameters
---------- ----------
A : torch.Tensor, SparseMatrix or DiagMatrix A : torch.Tensor or SparseMatrix
The first matrix. The first matrix.
B : torch.Tensor, SparseMatrix, or DiagMatrix B : torch.Tensor or SparseMatrix
The second matrix. The second matrix.
Returns Returns
------- -------
torch.Tensor, SparseMatrix or DiagMatrix torch.Tensor or SparseMatrix
The result matrix The result matrix
Examples Examples
...@@ -289,7 +183,7 @@ def matmul( ...@@ -289,7 +183,7 @@ def matmul(
Multiplies a sparse matrix with a dense matrix. Multiplies a sparse matrix with a dense matrix.
>>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]]) >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val = torch.randn(len(row)) >>> val = torch.randn(indices.shape[1])
>>> A = dglsp.spmatrix(indices, val) >>> A = dglsp.spmatrix(indices, val)
>>> X = torch.randn(2, 3) >>> X = torch.randn(2, 3)
>>> result = dglsp.matmul(A, X) >>> result = dglsp.matmul(A, X)
...@@ -301,10 +195,10 @@ def matmul( ...@@ -301,10 +195,10 @@ def matmul(
Multiplies a sparse matrix with a sparse matrix. Multiplies a sparse matrix with a sparse matrix.
>>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val1 = torch.ones(len(row1)) >>> val1 = torch.ones(indices1.shape[1])
>>> A = dglsp.spmatrix(indices1, val1) >>> A = dglsp.spmatrix(indices1, val1)
>>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]]) >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
>>> val2 = torch.ones(len(row2)) >>> val2 = torch.ones(indices2.shape[1])
>>> B = dglsp.spmatrix(indices2, val2) >>> B = dglsp.spmatrix(indices2, val2)
>>> result = dglsp.matmul(A, B) >>> result = dglsp.matmul(A, B)
>>> type(result) >>> type(result)
...@@ -312,12 +206,11 @@ def matmul( ...@@ -312,12 +206,11 @@ def matmul(
>>> result.shape >>> result.shape
(2, 3) (2, 3)
""" """
assert isinstance(A, (torch.Tensor, SparseMatrix, DiagMatrix)), ( assert isinstance(
f"Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix object," A, (torch.Tensor, SparseMatrix)
f"got {type(A)}." ), f"Expect arg1 to be a torch.Tensor or SparseMatrix, got {type(A)}."
) assert isinstance(B, (torch.Tensor, SparseMatrix)), (
assert isinstance(B, (torch.Tensor, SparseMatrix, DiagMatrix)), ( f"Expect arg2 to be a torch Tensor or SparseMatrix"
f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
f"object, got {type(B)}." f"object, got {type(B)}."
) )
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
...@@ -328,10 +221,7 @@ def matmul( ...@@ -328,10 +221,7 @@ def matmul(
) )
if isinstance(B, torch.Tensor): if isinstance(B, torch.Tensor):
return spmm(A, B) return spmm(A, B)
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B)
return spspmm(A, B) return spspmm(A, B)
SparseMatrix.__matmul__ = matmul SparseMatrix.__matmul__ = matmul
DiagMatrix.__matmul__ = matmul
...@@ -475,6 +475,10 @@ class SparseMatrix: ...@@ -475,6 +475,10 @@ class SparseMatrix:
""" """
return self.c_sparse_matrix.has_duplicate() return self.c_sparse_matrix.has_duplicate()
def is_diag(self):
"""Returns whether the sparse matrix is a diagonal matrix."""
return self.c_sparse_matrix.is_diag()
def spmatrix( def spmatrix(
indices: torch.Tensor, indices: torch.Tensor,
...@@ -859,6 +863,144 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix: ...@@ -859,6 +863,144 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val)) return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
def diag(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
) -> SparseMatrix:
"""Creates a sparse matrix based on the diagonal values.
Parameters
----------
val : torch.Tensor
Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``
shape : tuple[int, int], optional
If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,
otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``
Returns
-------
SparseMatrix
Sparse matrix
Examples
--------
Case1: 5-by-5 diagonal matrix with scaler values on the diagonal
>>> import torch
>>> val = torch.ones(5)
>>> dglsp.diag(val)
SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(5, 5), nnz=5)
Case2: 5-by-10 diagonal matrix with scaler values on the diagonal
>>> val = torch.ones(5)
>>> dglsp.diag(val, shape=(5, 10))
SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(5, 10), nnz=5)
Case3: 5-by-5 diagonal matrix with vector values on the diagonal
>>> val = torch.randn(5, 3)
>>> D = dglsp.diag(val)
>>> D.shape
(5, 5)
>>> D.nnz
5
"""
assert (
val.dim() <= 2
), "The values of a DiagMatrix can only be scalars or vectors."
len_val = len(val)
if shape is not None:
assert len_val == min(shape), (
f"Expect len(val) to be min(shape) for a diagonal matrix, got"
f"{len_val} for len(val) and {shape} for shape."
)
else:
shape = (len_val, len_val)
return SparseMatrix(torch.ops.dgl_sparse.from_diag(val, shape))
def identity(
shape: Tuple[int, int],
d: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> SparseMatrix:
r"""Creates a sparse matrix with ones on the diagonal and zeros elsewhere.
Parameters
----------
shape : tuple[int, int]
Shape of the matrix.
d : int, optional
If None, the diagonal entries will be scaler 1. Otherwise, the diagonal
entries will be a 1-valued tensor of shape ``(d)``.
dtype : torch.dtype, optional
The data type of the matrix
device : torch.device, optional
The device of the matrix
Returns
-------
SparseMatrix
Sparse matrix
Examples
--------
Case1: 3-by-3 matrix with scaler diagonal values
.. code::
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
>>> dglsp.identity(shape=(3, 3))
SparseMatrix(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1., 1., 1.]),
shape=(3, 3), nnz=3)
Case2: 3-by-5 matrix with scaler diagonal values
.. code::
[[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]]
>>> dglsp.identity(shape=(3, 5))
SparseMatrix(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1., 1., 1.]),
shape=(3, 5), nnz=3)
Case3: 3-by-3 matrix with vector diagonal values
>>> dglsp.identity(shape=(3, 3), d=2)
SparseMatrix(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([[1., 1.],
[1., 1.],
[1., 1.]]),
shape=(3, 3), nnz=3, val_size=(2,))
"""
len_val = min(shape)
if d is None:
val_shape = (len_val,)
else:
val_shape = (len_val, d)
val = torch.ones(val_shape, dtype=dtype, device=device)
return diag(val, shape)
def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix: def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
"""Creates a sparse matrix from a torch sparse tensor, which can have coo, """Creates a sparse matrix from a torch sparse tensor, which can have coo,
csr, or csc layout. csr, or csc layout.
......
"""DGL unary operators for sparse matrix module.""" """DGL unary operators for sparse matrix module."""
from .sparse_matrix import SparseMatrix, val_like from .sparse_matrix import diag, SparseMatrix, val_like
def neg(A: SparseMatrix) -> SparseMatrix: def neg(A: SparseMatrix) -> SparseMatrix:
...@@ -26,5 +26,36 @@ def neg(A: SparseMatrix) -> SparseMatrix: ...@@ -26,5 +26,36 @@ def neg(A: SparseMatrix) -> SparseMatrix:
return val_like(A, -A.val) return val_like(A, -A.val)
def inv(A: SparseMatrix) -> SparseMatrix:
"""Returns the inverse of the sparse matrix.
This function only supports square diagonal matrices with scalar nonzero
values.
Returns
-------
SparseMatrix
Inverse of the sparse matrix
Examples
--------
>>> val = torch.arange(1, 4).float()
>>> D = dglsp.diag(val)
>>> D.inv()
SparseMatrix(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1., 2., 3.]),
shape=(3, 3), nnz=3)
"""
num_rows, num_cols = A.shape
assert A.is_diag(), "Non-diagonal sparse matrix does not support inversion."
assert num_rows == num_cols, f"Expect a square matrix, got shape {A.shape}"
assert len(A.val.shape) == 1, "inv only supports 1D nonzero val"
return diag(1.0 / A.val, A.shape)
SparseMatrix.neg = neg SparseMatrix.neg = neg
SparseMatrix.__neg__ = neg SparseMatrix.__neg__ = neg
SparseMatrix.inv = inv
"""DGL unary operators for diagonal matrix module."""
# pylint: disable= invalid-name
from .diag_matrix import diag, DiagMatrix
def neg(D: DiagMatrix) -> DiagMatrix:
"""Returns a new diagonal matrix with the negation of the original nonzero
values, equivalent to ``-D``.
Returns
-------
DiagMatrix
Negation of the diagonal matrix
Examples
--------
>>> val = torch.arange(3).float()
>>> D = dglsp.diag(val)
>>> D = -D
DiagMatrix(val=tensor([-0., -1., -2.]),
shape=(3, 3))
"""
return diag(-D.val, D.shape)
def inv(D: DiagMatrix) -> DiagMatrix:
"""Returns the inverse of the diagonal matrix.
This function only supports square matrices with scalar nonzero values.
Returns
-------
DiagMatrix
Inverse of the diagonal matrix
Examples
--------
>>> val = torch.arange(1, 4).float()
>>> D = dglsp.diag(val)
>>> D = D.inv()
DiagMatrix(val=tensor([1.0000, 0.5000, 0.3333]),
shape=(3, 3))
"""
num_rows, num_cols = D.shape
assert num_rows == num_cols, f"Expect a square matrix, got shape {D.shape}"
assert len(D.val.shape) == 1, "inv only supports 1D nonzero val"
return diag(1.0 / D.val, D.shape)
DiagMatrix.neg = neg
DiagMatrix.__neg__ = neg
DiagMatrix.inv = inv
import sys
import unittest
import backend as F
import pytest
import torch
from dgl.sparse import diag, DiagMatrix, identity
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
ctx = F.ctx()
# creation
val = torch.randn(val_shape).to(ctx)
mat = diag(val, mat_shape)
# val, shape attributes
assert torch.allclose(mat.val, val)
if mat_shape is None:
mat_shape = (val_shape[0], val_shape[0])
assert mat.shape == mat_shape
val = torch.randn(val_shape).to(ctx)
# nnz
assert mat.nnz == val.shape[0]
# dtype
assert mat.dtype == val.dtype
# device
assert mat.device == val.device
# as_sparse
sp_mat = mat.to_sparse()
# shape
assert tuple(sp_mat.shape) == mat_shape
# nnz
assert sp_mat.nnz == mat.nnz
# dtype
assert sp_mat.dtype == mat.dtype
# device
assert sp_mat.device == mat.device
# row, col, val
edge_index = torch.arange(len(val)).to(mat.device)
row, col = sp_mat.coo()
val = sp_mat.val
assert torch.allclose(row, edge_index)
assert torch.allclose(col, edge_index)
assert torch.allclose(val, val)
@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
@pytest.mark.parametrize("d", [None, 2])
def test_identity(shape, d):
ctx = F.ctx()
# creation
mat = identity(shape, d)
# type
assert isinstance(mat, DiagMatrix)
# shape
assert mat.shape == shape
# val
len_val = min(shape)
if d is None:
val_shape = len_val
else:
val_shape = (len_val, d)
val = torch.ones(val_shape)
assert torch.allclose(val, mat.val)
def test_print():
ctx = F.ctx()
# basic
val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)
A = diag(val)
print(A)
# vector-shape non zero
val = torch.randn(3, 2).to(ctx)
A = diag(val)
print(A)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Device conversions don't need to be tested on CPU.",
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_to_device(device):
val = torch.randn(3)
mat_shape = (3, 4)
mat = diag(val, mat_shape)
target_val = mat.val.to(device)
mat2 = mat.to(device=device)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
mat2 = getattr(mat, device)()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
@pytest.mark.parametrize(
"dtype", [torch.float, torch.double, torch.int, torch.long]
)
def test_to_dtype(dtype):
val = torch.randn(3)
mat_shape = (3, 4)
mat = diag(val, mat_shape)
target_val = mat.val.to(dtype=dtype)
mat2 = mat.to(dtype=dtype)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
func_name = {
torch.float: "float",
torch.double: "double",
torch.int: "int",
torch.long: "long",
}
mat2 = getattr(mat, func_name[dtype])()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
ctx = F.ctx()
val = torch.randn(val_shape).to(ctx)
mat = diag(val, mat_shape).transpose()
assert torch.allclose(mat.val, val)
if mat_shape is None:
mat_shape = (val_shape[0], val_shape[0])
assert mat.shape == mat_shape[::-1]
import operator import operator
import sys
import backend as F import backend as F
...@@ -7,6 +6,65 @@ import dgl.sparse as dglsp ...@@ -7,6 +6,65 @@ import dgl.sparse as dglsp
import pytest import pytest
import torch import torch
from dgl.sparse import diag, power
@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
def test_diag_op_diag(opname):
op = getattr(operator, opname)
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)
result = op(D1, D2)
assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
assert result.shape == D1.shape
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar):
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
# D * v
D2 = D1 * v_scalar
assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# v * D
D2 = v_scalar * D1
assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D / v
D2 = D1 / v_scalar
assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D ^ v
D1 = diag(torch.arange(1, 4).to(ctx))
D2 = D1**v_scalar
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# pow(D, v)
D2 = power(D1, v_scalar)
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
with pytest.raises(TypeError):
D1 + v_scalar
with pytest.raises(TypeError):
v_scalar + D1
with pytest.raises(TypeError):
D1 - v_scalar
with pytest.raises(TypeError):
v_scalar - D1
@pytest.mark.parametrize("val_shape", [(), (2,)]) @pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"]) @pytest.mark.parametrize("opname", ["add", "sub"])
......
import operator
import sys
import backend as F
import pytest
import torch
from dgl.sparse import diag, power
# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True)
def all_close_sparse(A, B):
assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))
assert torch.allclose(A.values(), B.values())
assert A.shape == B.shape
@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
def test_diag_op_diag(opname):
op = getattr(operator, opname)
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)
result = op(D1, D2)
assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
assert result.shape == D1.shape
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar):
ctx = F.ctx()
shape = (3, 4)
D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
# D * v
D2 = D1 * v_scalar
assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# v * D
D2 = v_scalar * D1
assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D / v
D2 = D1 / v_scalar
assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# D ^ v
D1 = diag(torch.arange(1, 4).to(ctx))
D2 = D1**v_scalar
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
# pow(D, v)
D2 = power(D1, v_scalar)
assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
assert D1.shape == D2.shape
with pytest.raises(TypeError):
D1 + v_scalar
with pytest.raises(TypeError):
v_scalar + D1
with pytest.raises(TypeError):
D1 - v_scalar
with pytest.raises(TypeError):
v_scalar - D1
...@@ -157,7 +157,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz): ...@@ -157,7 +157,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
B.val.backward(grad) B.val.backward(grad)
torch_A = sparse_matrix_to_torch_sparse(A) torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D.to_sparse()) torch_D = sparse_matrix_to_torch_sparse(D)
torch_B = torch.sparse.mm(torch_A, torch_D) torch_B = torch.sparse.mm(torch_A, torch_D)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad) torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad) torch_B.backward(torch_B_grad)
...@@ -190,7 +190,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz): ...@@ -190,7 +190,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
B.val.backward(grad) B.val.backward(grad)
torch_A = sparse_matrix_to_torch_sparse(A) torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D.to_sparse()) torch_D = sparse_matrix_to_torch_sparse(D)
torch_B = torch.sparse.mm(torch_D, torch_A) torch_B = torch.sparse.mm(torch_D, torch_A)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad) torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad) torch_B.backward(torch_B_grad)
......
...@@ -6,10 +6,12 @@ import pytest ...@@ -6,10 +6,12 @@ import pytest
import torch import torch
from dgl.sparse import ( from dgl.sparse import (
diag,
from_coo, from_coo,
from_csc, from_csc,
from_csr, from_csr,
from_torch_sparse, from_torch_sparse,
identity,
to_torch_sparse_coo, to_torch_sparse_coo,
to_torch_sparse_csc, to_torch_sparse_csc,
to_torch_sparse_csr, to_torch_sparse_csr,
...@@ -606,3 +608,69 @@ def test_torch_sparse_csc_conversion(indptr, indices, shape): ...@@ -606,3 +608,69 @@ def test_torch_sparse_csc_conversion(indptr, indices, shape):
_assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc) _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)
torch_sparse_csc = to_torch_sparse_csc(spmat) torch_sparse_csc = to_torch_sparse_csc(spmat)
_assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc) _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)
### Diag foramt related tests ###
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
ctx = F.ctx()
# creation
val = torch.randn(val_shape).to(ctx)
mat = diag(val, mat_shape)
# val, shape attributes
assert torch.allclose(mat.val, val)
if mat_shape is None:
mat_shape = (val_shape[0], val_shape[0])
assert mat.shape == mat_shape
val = torch.randn(val_shape).to(ctx)
# nnz
assert mat.nnz == val.shape[0]
# dtype
assert mat.dtype == val.dtype
# device
assert mat.device == val.device
# row, col, val
edge_index = torch.arange(len(val)).to(mat.device)
row, col = mat.coo()
val = mat.val
assert torch.allclose(row, edge_index)
assert torch.allclose(col, edge_index)
assert torch.allclose(val, val)
@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
@pytest.mark.parametrize("d", [None, 2])
def test_identity(shape, d):
ctx = F.ctx()
# creation
mat = identity(shape, d)
# shape
assert mat.shape == shape
# val
len_val = min(shape)
if d is None:
val_shape = len_val
else:
val_shape = (len_val, d)
val = torch.ones(val_shape)
assert torch.allclose(val, mat.val)
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
ctx = F.ctx()
val = torch.randn(val_shape).to(ctx)
mat = diag(val, mat_shape).transpose()
assert torch.allclose(mat.val, val)
if mat_shape is None:
mat_shape = (val_shape[0], val_shape[0])
assert mat.shape == mat_shape[::-1]
...@@ -3,10 +3,24 @@ import sys ...@@ -3,10 +3,24 @@ import sys
import backend as F import backend as F
import torch import torch
from dgl.sparse import diag from dgl.sparse import diag, spmatrix
def test_neg(): def test_neg():
ctx = F.ctx()
row = torch.tensor([1, 1, 3]).to(ctx)
col = torch.tensor([1, 2, 3]).to(ctx)
val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)
A = spmatrix(torch.stack([row, col]), val)
neg_A = -A
assert A.shape == neg_A.shape
assert A.nnz == neg_A.nnz
assert torch.allclose(-A.val, neg_A.val)
assert torch.allclose(torch.stack(A.coo()), torch.stack(neg_A.coo()))
assert A.val.device == neg_A.val.device
def test_diag_neg():
ctx = F.ctx() ctx = F.ctx()
val = torch.arange(3).float().to(ctx) val = torch.arange(3).float().to(ctx)
D = diag(val) D = diag(val)
...@@ -16,7 +30,7 @@ def test_neg(): ...@@ -16,7 +30,7 @@ def test_neg():
assert D.val.device == neg_D.val.device assert D.val.device == neg_D.val.device
def test_inv(): def test_diag_inv():
ctx = F.ctx() ctx = F.ctx()
val = torch.arange(1, 4).float().to(ctx) val = torch.arange(1, 4).float().to(ctx)
D = diag(val) D = diag(val)
......
import sys
import backend as F
import torch
from dgl.sparse import from_coo
def test_neg():
ctx = F.ctx()
row = torch.tensor([1, 1, 3]).to(ctx)
col = torch.tensor([1, 2, 3]).to(ctx)
val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)
A = from_coo(row, col, val)
neg_A = -A
assert A.shape == neg_A.shape
assert A.nnz == neg_A.nnz
assert torch.allclose(-A.val, neg_A.val)
assert torch.allclose(torch.stack(A.coo()), torch.stack(neg_A.coo()))
assert A.val.device == neg_A.val.device
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment