Unverified Commit db64cc37 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add Transposition for SparseMatrix and DiagMatrix (#4597)



* [Sparse] Add Transposition

* Fix docstring

* Fix linting problem

* Minor fix

* Minor fix
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent d04534c9
......@@ -13,7 +13,7 @@ Sparse matrix class
.. currentmodule:: dgl.mock_sparse
.. autoclass:: SparseMatrix
:members: shape, nnz, dtype, device, row, col, val, __call__, indices, coo, csr, csc, dense
:members: shape, nnz, dtype, device, row, col, val, __call__, indices, coo, csr, csc, dense, t, T, transpose
.. autosummary::
:toctree: ../../generated/
......@@ -27,7 +27,7 @@ Diagonal matrix class
.. currentmodule:: dgl.mock_sparse
.. autoclass:: DiagMatrix
:members: val, shape, __call__, nnz, dtype, device, as_sparse
:members: val, shape, __call__, nnz, dtype, device, as_sparse, t, T, transpose
.. autosummary::
:toctree: ../../generated/
......
......@@ -123,6 +123,35 @@ class DiagMatrix:
row = col = torch.arange(len(self.val)).to(self.device)
return create_from_coo(row=row, col=col, val=self.val, shape=self.shape)
def t(self):
"""Alias of :meth:`transpose()`"""
return self.transpose()
@property
def T(self): # pylint: disable=C0103
"""Alias of :meth:`transpose()`"""
return self.transpose()
def transpose(self):
"""Return the transpose of the matrix.
Returns
-------
DiagMatrix
The transpose of the matrix.
Example
--------
>>> val = torch.arange(1, 5).float()
>>> mat = diag(val, shape=(4, 5))
>>> mat = mat.transpose()
>>> print(mat)
DiagMatrix(val=tensor([1., 2., 3., 4.]),
shape=(5, 4))
"""
return DiagMatrix(self.val, self.shape[::-1])
def diag(val: torch.Tensor, shape: Optional[Tuple[int, int]] = None) -> DiagMatrix:
"""Create a diagonal matrix based on the diagonal values
......
......@@ -236,6 +236,39 @@ class SparseMatrix:
"""
return self.adj.to_dense()
def t(self):
"""Alias of :meth:`transpose()`"""
return self.transpose()
@property
def T(self): # pylint: disable=C0103
"""Alias of :meth:`transpose()`"""
return self.transpose()
def transpose(self):
"""Return the transpose of this sparse matrix.
Returns
-------
SparseMatrix
The transpose of this sparse matrix.
Example
-------
>>> row = torch.tensor([1, 1, 3])
>>> col = torch.tensor([2, 1, 3])
>>> val = torch.tensor([1, 1, 2])
>>> A = create_from_coo(row, col, val)
>>> A = A.transpose()
>>> print(A)
SparseMatrix(indices=tensor([[1, 2, 3],
[1, 1, 3]]),
values=tensor([1, 1, 2]),
shape=(4, 4), nnz=3)
"""
return SparseMatrix(self.col, self.row, self.val, self.shape[::-1])
def create_from_coo(row: torch.Tensor,
col: torch.Tensor,
val: Optional[torch.Tensor] = None,
......
import pytest
import torch
from dgl.mock_sparse import diag, create_from_coo
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
val = torch.randn(val_shape)
mat = diag(val, mat_shape).transpose()
assert torch.allclose(mat.val, val)
if mat_shape is None:
mat_shape = (val_shape[0], val_shape[0])
assert mat.shape == mat_shape[::-1]
@pytest.mark.parametrize("dense_dim", [None, 2])
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("mat_shape", [(3, 5), (5, 3)])
def test_sparse_matrix_transpose(dense_dim, row, col, mat_shape):
# Skip invalid matrices
if max(row) >= mat_shape[0] or max(col) >= mat_shape[1]:
return
val_shape = (len(row),)
if dense_dim is not None:
val_shape += (dense_dim,)
val = torch.randn(val_shape)
row = torch.tensor(row)
col = torch.tensor(col)
mat = create_from_coo(row, col, val, mat_shape).transpose()
assert mat.shape == mat_shape[::-1]
assert torch.allclose(mat.val, val)
assert torch.allclose(mat.row, col)
assert torch.allclose(mat.col, row)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment