"docs/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "6bef412a46aa115574e0e5feeca890c20b59f0c8"
Unverified commit 977b1ba4, authored by Mufei Li and committed by GitHub

[Sparse] Migration of Unary, Softmax, Matmul Ops (#4616)



* Update

* lint

* Fix

* Update

* Update

* Update

* update

* Fix

* Update

* Fix

* Fix

* Fix

* CI

* Update

* Update

* Update

* update test
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
parent b3ae13ab
@@ -14,7 +14,7 @@ Sparse matrix class
 .. autoclass:: SparseMatrix
     :members: shape, nnz, dtype, device, row, col, val, __call__, indices, coo, csr, csc, dense, t, T, transpose,
-        reduce, sum, smax, smin, smean
+        reduce, sum, smax, smin, smean, __neg__, inv, softmax, __matmul__
 .. autosummary::
     :toctree: ../../generated/
@@ -29,10 +29,22 @@ Diagonal matrix class
 .. autoclass:: DiagMatrix
     :members: val, shape, __call__, nnz, dtype, device, as_sparse, t, T, transpose,
-        reduce, sum, smax, smin, smean
+        reduce, sum, smax, smin, smean, __neg__, inv, softmax, __matmul__
 .. autosummary::
     :toctree: ../../generated/
     diag
     identity
+Operators
+---------
+
+.. currentmodule:: dgl.mock_sparse
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    spmm
+    spspmm
+    bspmm
+    bspspmm
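For orientation, here is a minimal usage sketch of the operators documented above, assembled from the doctest examples in this commit; the shapes and values are illustrative, and the imports assume the dgl.mock_sparse module as extended here.

import torch
from dgl.mock_sparse import create_from_coo, spmm, spspmm, bspmm

# A 2 x 3 sparse matrix with three nonzero entries.
A = create_from_coo(row=torch.tensor([0, 1, 1]),
                    col=torch.tensor([1, 0, 1]),
                    val=torch.randn(3), shape=(2, 3))

# Sparse-dense multiplication; spmm(A, X) is equivalent to A @ X.
X = torch.randn(3, 4)
assert spmm(A, X).shape == (2, 4)

# Sparse-sparse multiplication; spspmm(A, B) is equivalent to A @ B.
B = create_from_coo(row=torch.tensor([0, 2]),
                    col=torch.tensor([0, 1]),
                    val=torch.randn(2), shape=(3, 2))
assert spspmm(A, B).shape == (2, 2)

# Batched variants treat the last dimension H of the values as the batch.
H = 4
A_batched = create_from_coo(row=torch.tensor([0, 1, 1]),
                            col=torch.tensor([1, 0, 1]),
                            val=torch.randn(3, H), shape=(2, 3))
assert bspmm(A_batched, torch.randn(3, 4, H)).shape == (2, 4, H)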
@@ -3,3 +3,6 @@ from .diag_matrix import *
 from .sp_matrix import *
 from .elementwise_op_sp import *
 from .reduction import *  # pylint: disable=W0622
+from .unary_diag import *
+from .unary_sp import *
+from .matmul import *
"""Matmul ops for SparseMatrix"""
# pylint: disable=invalid-name
from typing import Union, List
import torch
from .diag_matrix import DiagMatrix, diag
from .sp_matrix import SparseMatrix, create_from_coo
__all__ = [
'spmm',
'spspmm',
'bspmm',
'bspspmm'
]
def _sparse_dense_mm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
"""Internal function for multiplying a sparse matrix by a dense matrix
Parameters
----------
A : SparseMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
X : torch.Tensor
Dense tensor of shape (M, F) or (M)
Returns
-------
torch.Tensor
The result of multiplication
"""
return torch.matmul(A.adj, X)
def _sparse_sparse_mm(A1: SparseMatrix, A2: SparseMatrix) -> SparseMatrix:
"""Internal function for multiplying a sparse matrix by a sparse matrix
Parameters
----------
A1 : SparseMatrix
Sparse matrix of shape (N, M) with values of shape (nnz1)
A2 : SparseMatrix
Sparse matrix of shape (M, P) with values of shape (nnz2)
Returns
-------
SparseMatrix
The result of multiplication
"""
result = torch.sparse.mm(A1.adj, A2.adj).coalesce()
row, col = result.indices()
return create_from_coo(row=row,
col=col,
val=result.values(),
shape=result.size())
def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix
Parameters
----------
A1 : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Returns
-------
DiagMatrix
The result of multiplication.
"""
M, N = A1.shape
N, P = A2.shape
common_diag_len = min(M, N, P)
new_diag_len = min(M, P)
diag_val = torch.zeros(new_diag_len)
diag_val[:common_diag_len] = A1.val[:common_diag_len] * A2.val[:common_diag_len]
return diag(diag_val.to(A1.device), (M, P))
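# Worked example for _diag_diag_mm (hypothetical values): for
# A1 = diag([1., 2.]) of shape (2, 3) and A2 = diag([3., 4., 5.]) of shape
# (3, 2), only min(2, 3, 2) = 2 diagonal entries overlap, so the result is
# diag([1. * 3., 2. * 4.]) = diag([3., 8.]) of shape (2, 2).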
def _unbatch_tensor(A: Union[torch.Tensor, SparseMatrix, DiagMatrix])\
-> Union[List[torch.Tensor], List[SparseMatrix], List[DiagMatrix]]:
"""Internal function for unbatching a tensor, sparse matrix, or diagonal matrix
Parameters
----------
A : torch.Tensor or SparseMatrix, or DiagMatrix
Batched matrix/tensor
Returns
-------
list[torch.Tensor] or list[SparseMatrix] or list[DiagMatrix]
Unbatched matrices/tensors
"""
if isinstance(A, torch.Tensor):
return [A[..., i] for i in range(A.shape[-1])]
elif isinstance(A, SparseMatrix):
return [
create_from_coo(row=A.row, col=A.col, val=A.val[:, i], shape=A.shape)
for i in range(A.val.shape[-1])]
else:
return [diag(A.val[:, i], A.shape) for i in range(A.val.shape[-1])]
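# For example, a SparseMatrix with values of shape (nnz, H) unbatches into H
# SparseMatrix objects sharing the same indices, the i-th carrying val[:, i].
# _batch_tensor below inverts this by stacking along the last dimension.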
def _batch_tensor(A_list: Union[List[torch.Tensor], List[SparseMatrix], List[DiagMatrix]])\
-> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
"""Internal function for batching a list of tensors, sparse matrices, or diagonal matrices
Parameters
----------
A_list : list[torch.Tensor] or list[SparseMatrix] or list[DiagMatrix]
A list of tensors, sparse matrices, or diagonal matrices
Returns
-------
torch.Tensor or SparseMatrix, or DiagMatrix
Batched matrix/tensor
"""
A = A_list[0]
if isinstance(A, torch.Tensor):
return torch.stack(A_list, dim=-1)
elif isinstance(A, SparseMatrix):
return create_from_coo(
row=A.row, col=A.col,
val=torch.stack([A_list[i].val for i in range(len(A_list))], dim=-1), shape=A.shape)
else:
return diag(
val=torch.stack([A_list[i].val for i in range(len(A_list))], dim=-1), shape=A.shape)
def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
"""Multiply a sparse matrix by a dense matrix
Parameters
----------
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
X : torch.Tensor
Dense tensor of shape (M, F) or (M)
Returns
-------
torch.Tensor
The result of multiplication
Examples
--------
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([1, 0, 1])
>>> val = torch.randn(len(row))
>>> A = create_from_coo(row, col, val)
>>> X = torch.randn(2, 3)
>>> result = A @ X
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([2, 3])
"""
assert isinstance(A, (SparseMatrix, DiagMatrix)), \
f'Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}'
assert isinstance(X, torch.Tensor), f'Expect arg2 to be a torch.Tensor, got {type(X)}'
assert A.shape[1] == X.shape[0], \
f'Expect arg1.shape[1] == arg2.shape[0], got {A.shape[1]} and {X.shape[0]}'
val_dim = len(A.val.shape)
assert val_dim == 1, f'Expect arg1.val to be a 1D tensor, got {val_dim}D'
val_dim = len(X.shape)
assert val_dim <= 2, f'Expect arg2 to be a 1D/2D tensor, got {val_dim}D'
if isinstance(A, SparseMatrix):
return _sparse_dense_mm(A, X)
else:
return _sparse_dense_mm(A.as_sparse(), X)
def spspmm(A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix])\
-> Union[SparseMatrix, DiagMatrix]:
"""Multiply a sparse matrix by a sparse matrix
Parameters
----------
A1 : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
A2 : SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz)
Returns
-------
SparseMatrix or DiagMatrix
The result of multiplication. It is a DiagMatrix object if both matrices are
DiagMatrix objects. It is a SparseMatrix object otherwise.
Examples
--------
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> A1 = create_from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> A2 = create_from_coo(row2, col2, val2)
>>> result = A1 @ A2
>>> print(result)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(2, 3), nnz=5)
"""
assert isinstance(A1, (SparseMatrix, DiagMatrix)), \
f'Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}'
assert isinstance(A2, (SparseMatrix, DiagMatrix)), \
f'Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}'
assert A1.shape[1] == A2.shape[0], \
f'Expect A1.shape[1] == A2.shape[0], got {A1.shape[1]} and {A2.shape[0]}'
val_dim = len(A1.val.shape)
assert val_dim == 1, f'Expect A1.val to be a 1D tensor, got {val_dim}D'
val_dim = len(A2.val.shape)
assert val_dim == 1, f'Expect A2.val to be a 1D tensor, got {val_dim}D'
if isinstance(A1, SparseMatrix):
if isinstance(A2, SparseMatrix):
return _sparse_sparse_mm(A1, A2)
else:
return _sparse_sparse_mm(A1, A2.as_sparse())
else:
if isinstance(A2, SparseMatrix):
return _sparse_sparse_mm(A1.as_sparse(), A2)
else:
return _diag_diag_mm(A1, A2)
def mm_sp(A1: SparseMatrix, A2: Union[torch.Tensor, SparseMatrix, DiagMatrix])\
-> Union[torch.Tensor, SparseMatrix]:
"""Internal function for multiplying a sparse matrix by a dense/sparse/diagonal matrix
Parameters
----------
A1 : SparseMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
it should have values of shape (nnz2)
Returns
-------
torch.Tensor or SparseMatrix
The result of multiplication.
* It is a dense torch tensor if :attr:`A2` is so.
* It is a SparseMatrix object otherwise.
Examples
--------
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([1, 0, 1])
>>> val = torch.randn(len(row))
>>> A1 = create_from_coo(row, col, val)
>>> A2 = torch.randn(2, 3)
>>> result = A1 @ A2
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([2, 3])
"""
assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), \
f'Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix object, got {type(A2)}'
if isinstance(A2, torch.Tensor):
return spmm(A1, A2)
else:
return spspmm(A1, A2)
def mm_diag(A1: DiagMatrix, A2: Union[torch.Tensor, SparseMatrix, DiagMatrix])\
-> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
"""Multiply a diagonal matrix by a dense/sparse/diagonal matrix
Parameters
----------
A1 : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
it should have values of shape (nnz2).
Returns
-------
torch.Tensor or DiagMatrix or SparseMatrix
The result of multiplication.
* It is a dense torch tensor if :attr:`A2` is so.
* It is a DiagMatrix object if :attr:`A2` is so.
* It is a SparseMatrix object otherwise.
Examples
--------
>>> val = torch.randn(3)
>>> A1 = diag(val)
>>> A2 = torch.randn(3, 2)
>>> result = A1 @ A2
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([3, 2])
"""
assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), \
f'Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix object, got {type(A2)}'
if isinstance(A2, torch.Tensor):
return spmm(A1, A2)
else:
return spspmm(A1, A2)
SparseMatrix.__matmul__ = mm_sp
DiagMatrix.__matmul__ = mm_diag
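# With these assignments, the @ operator dispatches on operand types:
# SparseMatrix @ Tensor and DiagMatrix @ Tensor go through spmm, while
# products of SparseMatrix/DiagMatrix operands go through spspmm.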
def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor)\
-> torch.Tensor:
"""Batched multiplication of a sparse matrix by a dense matrix,
with the last dimension being the batch dimension
We may consider a SparseMatrix/DiagMatrix with shape (N, M) and values of shape (nnz, H)
to be a tensor of shape (N, M, H). The result is then obtained by
.. code::
result = []
for i in range(H):
# If X is a 2D torch Tensor, then this will be X[:, i]
result.append(A[:, :, i] @ X[:, :, i])
result = torch.stack(result, dim=-1)
Parameters
----------
A : SparseMatrix or DiagMatrix
Matrix of shape (N, M), with values of shape (nnz, H)
X : torch.Tensor
Dense tensor of shape (M, F, H) or (M, H)
Returns
-------
torch.Tensor
The result of multiplication
Examples
--------
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([1, 0, 1])
>>> H = 4
>>> val = torch.randn(len(row), H)
>>> A = create_from_coo(row, col, val)
>>> X = torch.randn(2, 3, H)
>>> result = bspmm(A, X)
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([2, 3, 4])
"""
assert isinstance(A, (SparseMatrix, DiagMatrix)), \
f'Expect A to be a SparseMatrix or DiagMatrix object, got {type(A)}'
assert isinstance(X, torch.Tensor), f'Expect X to be a torch Tensor, got {type(X)}'
val_dim = len(A.val.shape)
assert val_dim == 2, f'Expect A.val to be a 2D tensor, got {val_dim}D'
H1 = A.val.shape[-1]
val_dim = len(X.shape)
assert val_dim in [2, 3], f'Expect X to be a 2D/3D tensor, got {val_dim}D'
H2 = X.shape[-1]
assert H1 == H2, f'Expect A.val.shape[-1] == X.shape[-1], got {H1} and {H2}'
A_unbatched = _unbatch_tensor(A)
X_unbatched = _unbatch_tensor(X)
results = [spmm(A_unbatched[i], X_unbatched[i]) for i in range(H1)]
return _batch_tensor(results)
def bspspmm(A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix])\
-> Union[SparseMatrix, DiagMatrix]:
"""Batched multiplication of a sparse matrix by a sparse matrix,
with the last dimension being the batch dimension
We may consider a SparseMatrix/DiagMatrix with shape (N, M) and values of shape (nnz1, H)
to be a tensor of shape (N, M, H). The result is then obtained by
.. code::
result = []
for i in range(H):
# A1[:, :, i] denotes the matrix carrying values A1.val[:, i]
result.append(A1[:, :, i] @ A2[:, :, i])
result = torch.stack(result, dim=-1)
Parameters
----------
A1 : SparseMatrix or DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1, H)
A2 : SparseMatrix or DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2, H)
Returns
-------
SparseMatrix or DiagMatrix
The result of multiplication
* It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
* It is a SparseMatrix object otherwise.
Examples
--------
>>> H = 4
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1), H)
>>> A1 = create_from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2), H)
>>> A2 = create_from_coo(row2, col2, val2)
>>> sparse_result = bspspmm(A1, A2)
>>> print(sparse_result)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
[1, 2, 0, 1, 2]]),
values=tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]),
shape=(2, 3), nnz=5)
"""
assert isinstance(A1, (SparseMatrix, DiagMatrix)), \
f'Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}'
assert isinstance(A2, (SparseMatrix, DiagMatrix)), \
f'Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}'
val_dim = len(A1.val.shape)
assert val_dim == 2, f'Expect A1.val to be a 2D tensor, got {val_dim}D'
H1 = A1.val.shape[-1]
val_dim = len(A2.val.shape)
assert val_dim == 2, f'Expect A2.val to be a 2D tensor, got {val_dim}D'
H2 = A2.val.shape[-1]
assert H1 == H2, f'Expect A1.val.shape[-1] == A2.val.shape[-1], got {H1} and {H2}'
A1_unbatched = _unbatch_tensor(A1)
A2_unbatched = _unbatch_tensor(A2)
results = [spspmm(A1_unbatched[i], A2_unbatched[i]) for i in range(H1)]
return _batch_tensor(results)
@@ -53,12 +53,8 @@ class SparseMatrix:
         val: Optional[torch.Tensor] = None,
         shape : Optional[Tuple[int, int]] = None
     ):
-        self._row = row
-        self._col = col
         if val is None:
             val = torch.ones(row.shape[0])
-        self._val = val
         i = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)
         if shape is None:
             self.adj = torch.sparse_coo_tensor(i, val).coalesce()
@@ -152,7 +148,6 @@ class SparseMatrix:
     def val(self, x) -> torch.tensor:
         """Set the values of the nonzero elements."""
         assert len(x) == self.nnz
-        self._val = x
         if len(x.shape) == 1:
             shape = self.shape
         else:
...
"""Unary ops for DiagMatrix"""
# pylint: disable=invalid-name
import torch
from .diag_matrix import DiagMatrix, diag
def neg(D: DiagMatrix) -> DiagMatrix:
"""Return a new diagonal matrix with negative elements.
Returns
-------
DiagMatrix
Negative of the diagonal matrix.
Examples
--------
>>> val = torch.arange(3).float()
>>> mat = diag(val)
>>> mat = -mat
>>> print(mat)
DiagMatrix(val=tensor([-0., -1., -2.]),
shape=(3, 3))
"""
return diag(-D.val, D.shape)
def inv(D: DiagMatrix) -> DiagMatrix:
"""Compute the inverse.
Only square matrices with values of shape (nnz) are supported.
Returns
-------
DiagMatrix
Inverse of the diagonal matrix.
Examples
--------
>>> val = torch.arange(1, 4).float()
>>> mat = diag(val)
>>> mat = mat.inv()
>>> print(mat)
DiagMatrix(val=tensor([1.0000, 0.5000, 0.3333]),
shape=(3, 3))
"""
num_rows, num_cols = D.shape
assert num_rows == num_cols, f'Expect a square matrix, got shape {D.shape}'
assert len(D.val.shape) == 1, 'inv only supports matrices with 1D val'
return diag(1. / D.val, D.shape)
def softmax(D: DiagMatrix) -> DiagMatrix:
"""Apply row-wise softmax to the nonzero entries of the diagonal matrix.
The result will be a diagonal matrix with one-valued diagonal.
Parameters
----------
D : DiagMatrix
The input diagonal matrix
Returns
-------
DiagMatrix
The result.
Examples
--------
Case 1: matrix with values of shape (nnz)
>>> val = torch.randn(3)
>>> D = diag(val)
>>> result = D.softmax()
>>> result.val
tensor([1., 1., 1.])
>>> result.shape
(3, 3)
Case 2: matrix with values of shape (nnz, D)
>>> val = torch.randn(3, 4)
>>> D = diag(val)
>>> result = D.softmax()
>>> result.val
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> result.shape
(3, 3)
"""
return diag(torch.ones_like(D.val), D.shape)
DiagMatrix.__neg__ = neg
DiagMatrix.inv = inv
DiagMatrix.softmax = softmax
"""Unary ops for SparseMatrix"""
# pylint: disable=invalid-name
import numpy as np
import torch
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import inv as scipy_inv
from .sp_matrix import SparseMatrix, create_from_coo
from ..convert import graph
from ..ops.edge_softmax import edge_softmax
def neg(A: SparseMatrix) -> SparseMatrix:
"""Return a new sparse matrix with negative elements.
Returns
-------
SparseMatrix
Negative of the sparse matrix.
Examples
--------
>>> row = torch.tensor([1, 1, 3])
>>> col = torch.tensor([1, 2, 3])
>>> val = torch.tensor([1., 1., 2.])
>>> A = create_from_coo(row, col, val)
>>> A = -A
>>> print(A)
SparseMatrix(indices=tensor([[1, 1, 3],
[1, 2, 3]]),
values=tensor([-1., -1., -2.]),
shape=(4, 4), nnz=3)
"""
return create_from_coo(row=A.row,
col=A.col,
val=-A.val,
shape=A.shape)
def inv(A: SparseMatrix) -> SparseMatrix:
"""Compute the inverse.
Only non-singular square matrices with values of shape (nnz) are supported.
Returns
-------
SparseMatrix
Inverse of the sparse matrix.
Examples
--------
A dense representation of the input matrix:
[[1., 0.],
[1., 2.]]
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([0, 0, 1])
>>> val = torch.tensor([1., 1., 2.])
>>> A = create_from_coo(row, col, val)
A dense representation of its inverse:
[[1.0, 0.0],
[-0.5, 0.5]]
>>> A_inv = A.inv()
>>> print(A_inv)
SparseMatrix(indices=tensor([[0, 1, 1],
[0, 0, 1]]),
values=tensor([1.0000, -0.5000, 0.5000]),
shape=(2, 2), nnz=3)
"""
num_rows, num_cols = A.shape
assert num_rows == num_cols, 'Expect a square matrix, got shape {}'.format(A.shape)
assert len(A.val.shape) == 1, 'inv only supports matrices with 1D val'
val = A.val.cpu().numpy()
row = A.row.cpu().numpy()
col = A.col.cpu().numpy()
# The computation is more efficient with CSC format.
mat = coo_matrix((val, (row, col)), dtype=val.dtype).tocsc()
mat_inv = scipy_inv(mat)
row, col = mat_inv.nonzero()
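# Indexing a SciPy sparse matrix with index arrays yields a 1 x nnz
# np.matrix; np.asarray(...).squeeze(0) below flattens it to a 1D array.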
val = mat_inv[row, col]
val = np.asarray(val).squeeze(0)
dev = A.device
return create_from_coo(row=torch.from_numpy(row).to(dev),
col=torch.from_numpy(col).to(dev),
val=torch.from_numpy(val).to(dev),
shape=A.shape)
def softmax(A: SparseMatrix) -> SparseMatrix:
"""Apply row-wise softmax to the nonzero entries of the sparse matrix.
If :attr:`A.val` takes shape :attr:`(nnz, D)`, then the output matrix
:attr:`A'` and :attr:`A'.val` take the same shape as :attr:`A` and :attr:`A.val`.
:attr:`A'.val[:, i]` is calculated based on :attr:`A.val[:, i]`.
Parameters
----------
A : SparseMatrix
The input sparse matrix
Returns
-------
SparseMatrix
The result, whose shape is the same as :attr:`A`
Examples
--------
Case 1: matrix with values of shape (nnz)
>>> row = torch.tensor([0, 0, 1, 2])
>>> col = torch.tensor([1, 2, 2, 0])
>>> val = torch.ones(len(row))
>>> A = create_from_coo(row, col, val)
>>> result = A.softmax()
>>> result.val
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> result.shape
(3, 3)
Case 2: matrix with values of shape (nnz, D)
>>> row = torch.tensor([0, 0, 1, 2])
>>> col = torch.tensor([1, 2, 2, 0])
>>> val = torch.ones(len(row), 2)
>>> A = create_from_coo(row, col, val)
>>> result = A.softmax()
>>> result.val
tensor([[0.5000, 0.5000],
[0.5000, 0.5000],
[1.0000, 1.0000],
[1.0000, 1.0000]])
>>> result.shape
(3, 3)
"""
g = graph((A.col, A.row))
return create_from_coo(A.row,
A.col,
edge_softmax(g, A.val),
A.shape)
SparseMatrix.__neg__ = neg
SparseMatrix.inv = inv
SparseMatrix.softmax = softmax
import torch
import backend as F
from dgl.mock_sparse import create_from_coo, diag, bspmm, bspspmm
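# Helper that densifies a SparseMatrix (or a DiagMatrix converted via
# as_sparse()), keeping a trailing batch dimension when A.val has shape
# (nnz, H).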
def get_adj(A):
edge_index = torch.cat((A.row.unsqueeze(0), A.col.unsqueeze(0)), 0)
shape = A.shape
if len(A.val.shape) > 1:
shape += (A.val.shape[-1],)
return torch.sparse_coo_tensor(edge_index, A.val, shape).coalesce().to_dense()
def test_sparse_dense_mm():
dev = F.ctx()
# A: shape (N, M), X: shape (M, F)
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val = torch.randn(len(row)).to(dev)
A = create_from_coo(row, col, val)
X = torch.randn(2, 3).to(dev)
sparse_result = A @ X
adj = get_adj(A)
dense_result = adj @ X
assert torch.allclose(sparse_result, dense_result)
# X: shape (M)
X = torch.randn(2).to(dev)
sparse_result = A @ X
dense_result = adj @ X
assert torch.allclose(sparse_result, dense_result)
def test_sparse_sparse_mm():
dev = F.ctx()
row1 = torch.tensor([0, 1, 1]).to(dev)
col1 = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row1)).to(dev)
A1 = create_from_coo(row1, col1, val1)
row2 = torch.tensor([0, 1, 1]).to(dev)
col2 = torch.tensor([0, 2, 1]).to(dev)
val2 = torch.randn(len(row2)).to(dev)
A2 = create_from_coo(row2, col2, val2)
sparse_result = get_adj(A1 @ A2)
dense_result = get_adj(A1) @ get_adj(A2)
assert torch.allclose(sparse_result, dense_result)
def test_sparse_diag_mm():
dev = F.ctx()
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row)).to(dev)
A = create_from_coo(row, col, val1)
val2 = torch.randn(2).to(dev)
D = diag(val2, (2, 3))
M1 = get_adj(A @ D)
M2 = get_adj(A @ D.as_sparse())
assert torch.allclose(M1, M2)
def test_diag_dense_mm():
dev = F.ctx()
# D: shape (N, N), X: shape (N, F)
val = torch.randn(3).to(dev)
D = diag(val)
X = torch.randn(3, 2).to(dev)
sparse_result = D @ X
dense_result = get_adj(D.as_sparse()) @ X
assert torch.allclose(sparse_result, dense_result)
# D: shape (N, M), N > M, X: shape (M, F)
val = torch.randn(3).to(dev)
D = diag(val, shape=(4, 3))
sparse_result = D @ X
dense_result = get_adj(D.as_sparse()) @ X
assert torch.allclose(sparse_result, dense_result)
# D: shape (N, M), N < M, X: shape (M, F)
val = torch.randn(2).to(dev)
D = diag(val, shape=(2, 3))
sparse_result = D @ X
dense_result = get_adj(D.as_sparse()) @ X
assert torch.allclose(sparse_result, dense_result)
# D: shape (N, M), X: shape (M)
val = torch.randn(3).to(dev)
D = diag(val)
X = torch.randn(3).to(dev)
sparse_result = D @ X
dense_result = get_adj(D.as_sparse()) @ X
assert torch.allclose(sparse_result, dense_result)
def test_diag_sparse_mm():
dev = F.ctx()
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row)).to(dev)
A = create_from_coo(row, col, val1)
val2 = torch.randn(2).to(dev)
D = diag(val2, (3, 2))
M1 = get_adj(D @ A)
M2 = get_adj(D.as_sparse() @ A)
assert torch.allclose(M1, M2)
def test_diag_diag_mm():
dev = F.ctx()
# D1, D2: shape (N, N)
val1 = torch.randn(3).to(dev)
D1 = diag(val1)
val2 = torch.randn(3).to(dev)
D2 = diag(val2)
sparse_result = D1 @ D2
assert torch.allclose(sparse_result.val, D1.val * D2.val)
# D1: shape (N, M), D2: shape (M, P)
N = 3
M = 4
P = 2
val1 = torch.randn(N).to(dev)
D1 = diag(val1, (N, M))
val2 = torch.randn(P).to(dev)
D2 = diag(val2, (M, P))
M1 = get_adj((D1 @ D2).as_sparse())
M2 = get_adj(D1.as_sparse() @ D2.as_sparse())
assert torch.allclose(M1, M2)
def test_batch_sparse_dense_mm():
dev = F.ctx()
# A: shape (N, M), val shape (nnz, H)
# X: shape (M, F, H)
H = 4
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val = torch.randn(len(row), H).to(dev)
A = create_from_coo(row, col, val)
X = torch.randn(2, 3, H).to(dev)
sparse_result = bspmm(A, X)
dense_A = get_adj(A)
dense_result = torch.stack([
dense_A[:, :, i] @ X[..., i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
# X: shape (M, H)
X = torch.randn(2, H).to(dev)
sparse_result = bspmm(A, X)
dense_A = get_adj(A)
dense_result = torch.stack([
dense_A[:, :, i] @ X[..., i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
def test_batch_sparse_sparse_mm():
H = 4
dev = F.ctx()
row1 = torch.tensor([0, 1, 1]).to(dev)
col1 = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row1), H).to(dev)
A1 = create_from_coo(row1, col1, val1)
row2 = torch.tensor([0, 1, 1]).to(dev)
col2 = torch.tensor([0, 2, 1]).to(dev)
val2 = torch.randn(len(row2), H).to(dev)
A2 = create_from_coo(row2, col2, val2)
sparse_result = get_adj(bspspmm(A1, A2))
dense_A1 = get_adj(A1)
dense_A2 = get_adj(A2)
dense_result = torch.stack([
dense_A1[:, :, i] @ dense_A2[:, :, i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
def test_batch_sparse_diag_mm():
H = 4
dev = F.ctx()
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row), H).to(dev)
A = create_from_coo(row, col, val1)
val2 = torch.randn(2, H).to(dev)
D = diag(val2, (2, 3))
sparse_result = get_adj(bspspmm(A, D))
dense_A = get_adj(A)
dense_D = get_adj(D.as_sparse())
dense_result = torch.stack([
dense_A[:, :, i] @ dense_D[:, :, i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
def test_batch_diag_dense_mm():
dev = F.ctx()
H = 4
# X: shape (N, F, H)
val = torch.randn(3, H).to(dev)
D = diag(val)
X = torch.randn(3, 2, H).to(dev)
sparse_result = bspmm(D, X)
dense_D = get_adj(D.as_sparse())
dense_result = torch.stack([
dense_D[:, :, i] @ X[..., i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
# X: shape (N, H)
X = torch.randn(3, H).to(dev)
sparse_result = bspmm(D, X)
dense_D = get_adj(D.as_sparse())
dense_result = torch.stack([
dense_D[:, :, i] @ X[..., i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
def test_batch_diag_sparse_mm():
dev = F.ctx()
H = 4
row = torch.tensor([0, 1, 1]).to(dev)
col = torch.tensor([1, 0, 1]).to(dev)
val1 = torch.randn(len(row), H).to(dev)
A = create_from_coo(row, col, val1)
val2 = torch.randn(2, H).to(dev)
D = diag(val2, (3, 2))
sparse_result = get_adj(bspspmm(D, A))
dense_A = get_adj(A)
dense_D = get_adj(D.as_sparse())
dense_result = torch.stack([
dense_D[:, :, i] @ dense_A[:, :, i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
def test_batch_diag_diag_mm():
dev = F.ctx()
H = 4
# D1, D2: shape (N, N)
val1 = torch.randn(3, H).to(dev)
D1 = diag(val1)
val2 = torch.randn(3, H).to(dev)
D2 = diag(val2)
M1 = bspspmm(D1, D2)
assert M1.shape == (3, 3)
assert torch.allclose(M1.val, val1 * val2)
# D1: shape (N, M), D2: shape (M, P)
N = 3
M = 4
P = 2
val1 = torch.randn(N, H).to(dev)
D1 = diag(val1, (N, M))
val2 = torch.randn(P, H).to(dev)
D2 = diag(val2, (M, P))
sparse_result = get_adj(bspspmm(D1, D2).as_sparse())
dense_D1 = get_adj(D1.as_sparse())
dense_D2 = get_adj(D2.as_sparse())
dense_result = torch.stack([
dense_D1[:, :, i] @ dense_D2[:, :, i] for i in range(H)
], dim=-1)
assert torch.allclose(sparse_result, dense_result)
import pytest
import torch
import backend as F
from dgl.convert import graph
from dgl.mock_sparse import diag, create_from_coo
from dgl.ops import edge_softmax
@pytest.mark.parametrize('val_shape', [(3,), (3, 2)])
@pytest.mark.parametrize('mat_shape', [(3, 3), (5, 3)])
def test_neg_diag(val_shape, mat_shape):
val = torch.randn(val_shape).to(F.ctx())
mat = diag(val, mat_shape)
neg_mat = -mat
assert neg_mat.shape == mat.shape
assert torch.allclose(-mat.val, neg_mat.val)
def test_inv_diag():
val = torch.randn(3).to(F.ctx())
mat = diag(val, (3, 3))
inv_mat = mat.inv()
assert inv_mat.shape == mat.shape
assert torch.allclose(1. / mat.val, inv_mat.val)
@pytest.mark.parametrize('val_shape', [(3,), (3, 2)])
@pytest.mark.parametrize('mat_shape', [(3, 3), (5, 3)])
def test_softmax_diag(val_shape, mat_shape):
val = torch.randn(val_shape).to(F.ctx())
mat = diag(val, mat_shape)
softmax_mat = mat.softmax()
assert softmax_mat.shape == mat.shape
assert torch.allclose(softmax_mat.val, torch.ones_like(mat.val))
@pytest.mark.parametrize('val_shape', [(3,), (3, 2)])
@pytest.mark.parametrize('mat_shape', [(4, 4), (5, 4)])
def test_neg_sp(val_shape, mat_shape):
device = F.ctx()
row = torch.tensor([1, 1, 3]).to(device)
col = torch.tensor([1, 2, 3]).to(device)
val = torch.randn(val_shape).to(device)
mat = create_from_coo(row, col, val, mat_shape)
neg_mat = -mat
assert neg_mat.shape == mat.shape
assert torch.allclose(-mat.val, neg_mat.val)
def test_inv_sp():
device = F.ctx()
row = torch.tensor([0, 1, 1]).to(device)
col = torch.tensor([0, 0, 1]).to(device)
val = torch.tensor([1., 1., 2.]).to(device)
mat = create_from_coo(row, col, val)
inv_mat = mat.inv()
assert inv_mat.shape == mat.shape
assert torch.allclose(torch.tensor([1., -0.5, 0.5]).to(device), inv_mat.val)
@pytest.mark.parametrize('val_shape', [(4,), (4, 2)])
def test_softmax_sp(val_shape):
device = F.ctx()
row = torch.tensor([0, 0, 1, 2]).to(device)
col = torch.tensor([1, 2, 2, 0]).to(device)
val = torch.randn(val_shape).to(device)
mat = create_from_coo(row, col, val)
result = mat.softmax()
assert result.shape == mat.shape
g = graph((mat.col, mat.row))
assert torch.allclose(result.val, edge_softmax(g, mat.val))