Unverified Commit bb7b3c6f authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add SparseMatrix unittest and fix docstring problem (#4627)



* [Sparse] Add SparseMatrix unittest and fix docstring problem

* Minor fix

* Update

* check permission

* rm future annotations

* Skip create_from_csr and create_from_csc tests because PyTorch 1.9.0 does not have torch.sparse_csr_tensor

Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent 2accfc5d
"""dgl diagonal matrix module."""
"""DGL diagonal matrix module."""
from typing import Optional, Tuple
import torch
from .sp_matrix import SparseMatrix, create_from_coo
class DiagMatrix:
"""Diagonal Matrix Class
......@@ -23,21 +24,27 @@ class DiagMatrix:
shape : tuple[int, int]
Shape of the matrix.
"""
def __init__(
    self, val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
):
    """Store the diagonal values and the shape of the matrix.

    Parameters
    ----------
    val : torch.Tensor
        Diagonal values. ``len(val)`` is taken as the diagonal length.
    shape : tuple[int, int], optional
        Shape of the matrix. When given, ``min(shape)`` must equal
        ``len(val)``. When omitted, the square shape
        ``(len(val), len(val))`` is used.

    Raises
    ------
    AssertionError
        If ``shape`` is given and ``len(val) != min(shape)``.
    """
    len_val = len(val)
    if shape is not None:
        # Every fragment of an implicitly concatenated string needs its
        # own ``f`` prefix; without it ``{shape}`` is emitted literally.
        # The trailing space keeps "len(val)" and "and" apart.
        assert len_val == min(shape), (
            f"Expect len(val) to be min(shape), got {len_val} for len(val) "
            f"and {shape} for shape."
        )
    else:
        shape = (len_val, len_val)
    self.val = val
    self.shape = shape
def __repr__(self):
    """Return a string showing the stored values and shape."""
    values, dims = self.val, self.shape
    return f"DiagMatrix(val={values}, \nshape={dims})"
def __call__(self, x: torch.Tensor):
"""Create a new diagonal matrix with the same shape as self but different values.
"""Create a new diagonal matrix with the same shape as self
but different values.
Parameters
----------
......@@ -152,7 +159,10 @@ class DiagMatrix:
"""
return DiagMatrix(self.val, self.shape[::-1])
def diag(val: torch.Tensor, shape: Optional[Tuple[int, int]] = None) -> DiagMatrix:
def diag(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
) -> DiagMatrix:
"""Create a diagonal matrix based on the diagonal values
Parameters
......@@ -200,10 +210,13 @@ def diag(val: torch.Tensor, shape: Optional[Tuple[int, int]] = None) -> DiagMatr
# NOTE(Mufei): this may not be needed if DiagMatrix is simple enough
return DiagMatrix(val, shape)
def identity(shape: Tuple[int, int],
def identity(
shape: Tuple[int, int],
d: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> DiagMatrix:
device: Optional[torch.device] = None,
) -> DiagMatrix:
"""Create a diagonal matrix with ones on the diagonal and zeros elsewhere
Parameters
......@@ -251,14 +264,11 @@ def identity(shape: Tuple[int, int],
Case3: 3-by-3 matrix with tensor diagonal values
>>> mat = identity(shape=(3, 3), d=2)
>>> mat.val
tensor([[1., 1.],
>>> print(mat)
DiagMatrix(val=tensor([[1., 1.],
[1., 1.],
[1., 1.]])
>>> mat.shape
(3, 3)
>>> mat.nnz
3
[1., 1.]]),
shape=(3, 3))
"""
len_val = min(shape)
if d is None:
......
"""dgl sparse matrix module."""
"""DGL sparse matrix module."""
from typing import Optional, Tuple
import torch
__all__ = ['SparseMatrix', 'create_from_coo', 'create_from_csr', 'create_from_csc']
__all__ = [
"SparseMatrix",
"create_from_coo",
"create_from_csr",
"create_from_csc",
]
class SparseMatrix:
r'''Class for sparse matrix.
r"""Class for sparse matrix.
Parameters
----------
......@@ -28,30 +34,31 @@ class SparseMatrix:
>>> dst = torch.tensor([2, 4, 3])
>>> val = torch.tensor([1, 1, 1])
>>> A = SparseMatrix(src, dst, val)
>>> A.shape
(3,5)
>>> A.row
tensor([1, 1, 2])
>>> A.val
tensor([1., 1., 1.])
>>> A.nnz
3
>>> print(A)
SparseMatrix(indices=tensor([[1, 1, 2],
[2, 4, 3]]),
values=tensor([1, 1, 1]),
shape=(3, 5), nnz=3)
Case2: Sparse matrix with row indices, col indices and values (vector).
>>> ...
>>> val = torch.tensor([[1, 1], [2, 2], [3, 3]])
>>> A = SparseMatrix(src, dst, val)
>>> A.val
tensor([[1, 1],
>>> print(A)
SparseMatrix(indices=tensor([[1, 1, 2],
[2, 4, 3]]),
values=tensor([[1, 1],
[2, 2],
[3, 3]])
'''
def __init__(self,
[3, 3]]),
shape=(3, 5), nnz=3)
"""
def __init__(
self,
row: torch.Tensor,
col: torch.Tensor,
val: Optional[torch.Tensor] = None,
shape : Optional[Tuple[int, int]] = None
shape: Optional[Tuple[int, int]] = None,
):
if val is None:
val = torch.ones(row.shape[0])
......@@ -112,7 +119,7 @@ class SparseMatrix:
return self.adj.device
@property
def row(self) -> torch.tensor:
def row(self) -> torch.Tensor:
"""Get the row indices of the nonzero elements.
Returns
......@@ -123,7 +130,7 @@ class SparseMatrix:
return self.adj.indices()[0]
@property
def col(self) -> torch.tensor:
def col(self) -> torch.Tensor:
"""Get the column indices of the nonzero elements.
Returns
......@@ -134,7 +141,7 @@ class SparseMatrix:
return self.adj.indices()[1]
@property
def val(self) -> torch.tensor:
def val(self) -> torch.Tensor:
"""Get the values of the nonzero elements.
Returns
......@@ -145,16 +152,18 @@ class SparseMatrix:
return self.adj.values()
@val.setter
def val(self, x: torch.Tensor) -> None:
    """Set the values of the nonzero elements.

    Rebuilds ``self.adj`` from the current indices and the new values,
    then coalesces it.

    Parameters
    ----------
    x : torch.Tensor
        New values. ``len(x)`` must equal ``self.nnz``.

    Raises
    ------
    AssertionError
        If ``len(x) != self.nnz``.
    """
    assert len(x) == self.nnz
    if len(x.shape) == 1:
        shape = self.shape
    else:
        # Non-1-D values: extend the matrix shape with the last dimension
        # of ``x``. Assumes ``self.shape`` is a tuple so ``+`` concatenates
        # (TODO confirm against the ``shape`` property).
        shape = self.shape + (x.shape[-1],)
    self.adj = torch.sparse_coo_tensor(
        self.adj.indices(), x, shape
    ).coalesce()
def __call__(self, x):
def __call__(self, x: torch.Tensor):
"""Create a new sparse matrix with the same sparsity as self but different values.
Parameters
......@@ -171,7 +180,9 @@ class SparseMatrix:
assert len(x) == self.nnz
return SparseMatrix(self.row, self.col, x, shape=self.shape)
def indices(self, fmt, return_shuffle=False) -> Tuple[torch.tensor, ...]:
def indices(
self, fmt: str, return_shuffle=False
) -> Tuple[torch.Tensor, ...]:
"""Get the indices of the nonzero elements.
Parameters
......@@ -186,12 +197,12 @@ class SparseMatrix:
tensor
Indices of the nonzero elements
"""
if fmt == 'COO' and not return_shuffle:
if fmt == "COO" and not return_shuffle:
return self.adj.indices()
else:
raise NotImplementedError
def coo(self) -> Tuple[torch.tensor, ...]:
def coo(self) -> Tuple[torch.Tensor, ...]:
"""Get the coordinate (COO) representation of the sparse matrix.
Returns
......@@ -201,7 +212,7 @@ class SparseMatrix:
"""
return self
def csr(self) -> Tuple[torch.tensor, ...]:
def csr(self) -> Tuple[torch.Tensor, ...]:
"""Get the CSR (Compressed Sparse Row) representation of the sparse matrix.
Returns
......@@ -211,7 +222,7 @@ class SparseMatrix:
"""
return self
def csc(self) -> Tuple[torch.tensor, ...]:
def csc(self) -> Tuple[torch.Tensor, ...]:
"""Get the CSC (Compressed Sparse Column) representation of the sparse matrix.
Returns
......@@ -221,7 +232,7 @@ class SparseMatrix:
"""
return self
def dense(self) -> torch.tensor:
def dense(self) -> torch.Tensor:
"""Get the dense representation of the sparse matrix.
Returns
......@@ -264,10 +275,13 @@ class SparseMatrix:
"""
return SparseMatrix(self.col, self.row, self.val, self.shape[::-1])
def create_from_coo(row: torch.Tensor,
def create_from_coo(
row: torch.Tensor,
col: torch.Tensor,
val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None) -> SparseMatrix:
shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
"""Create a sparse matrix from row and column coordinates.
Parameters
......@@ -324,10 +338,13 @@ def create_from_coo(row: torch.Tensor,
"""
return SparseMatrix(row=row, col=col, val=val, shape=shape)
def create_from_csr(indptr: torch.Tensor,
def create_from_csr(
indptr: torch.Tensor,
indices: torch.Tensor,
val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None) -> SparseMatrix:
shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
"""Create a sparse matrix from CSR indices.
For row i of the sparse matrix
......@@ -366,38 +383,49 @@ def create_from_csr(indptr: torch.Tensor,
>>> indptr = torch.tensor([0, 1, 2, 5])
>>> indices = torch.tensor([1, 2, 0, 1, 2])
>>> A = create_from_csr(indptr, indices)
>>> A
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(3, 3), nnz=5)
>>> # Specify shape
>>> A = create_from_csr(indptr, indices, shape=(5, 3))
>>> A.shape
(5, 3)
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(5, 3), nnz=5)
Case2: Sparse matrix with scalar/vector values. Following example is with
vector data.
>>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
>>> A = create_from_csr(indptr, indices, val)
>>> A.val
tensor([[1, 1],
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5]])
[5, 5]]),
shape=(3, 3), nnz=5)
"""
adj_csr = torch.sparse_csr_tensor(indptr, indices, torch.ones(indices.shape[0]))
adj_csr = torch.sparse_csr_tensor(
indptr, indices, torch.ones(indices.shape[0])
)
adj_coo = adj_csr.to_sparse_coo().coalesce()
row, col = adj_coo.indices()
return SparseMatrix(row=row, col=col, val=val, shape=shape)
def create_from_csc(indptr: torch.Tensor,
def create_from_csc(
indptr: torch.Tensor,
indices: torch.Tensor,
val: Optional[torch.Tensor] = None,
shape: Optional[Tuple[int, int]] = None) -> SparseMatrix:
shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
"""Create a sparse matrix from CSC indices.
For column i of the sparse matrix
......@@ -436,29 +464,37 @@ def create_from_csc(indptr: torch.Tensor,
>>> indptr = torch.tensor([0, 1, 3, 5])
>>> indices = torch.tensor([2, 0, 2, 1, 2])
>>> A = create_from_csc(indptr, indices)
>>> A
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(3, 3), nnz=5)
>>> # Specify shape
>>> A = create_from_csc(indptr, indices, shape=(5, 3))
>>> A.shape
(5, 3)
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(5, 3), nnz=5)
Case2: Sparse matrix with scalar/vector values. Following example is with
vector data.
>>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
>>> A = create_from_csc(indptr, indices, val)
>>> A.val
tensor([[1, 1],
[2, 2],
[3, 3],
>>> print(A)
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
[1, 2, 0, 1, 2]]),
values=tensor([[2, 2],
[4, 4],
[5, 5]])
[1, 1],
[3, 3],
[5, 5]]),
shape=(3, 3), nnz=5)
"""
adj_csr = torch.sparse_csr_tensor(indptr, indices, torch.ones(indices.shape[0]))
adj_csr = torch.sparse_csr_tensor(
indptr, indices, torch.ones(indices.shape[0])
)
adj_coo = adj_csr.to_sparse_coo().coalesce()
col, row = adj_coo.indices()
......
import pytest
import torch
from dgl.mock_sparse import create_from_coo, create_from_csr, create_from_csc
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_create_from_coo(dense_dim, row, col, mat_shape):
    """Check shape, nnz, dtype, device, values and indices of a COO matrix.

    ``dense_dim`` selects between 1-D values (``None``) and 2-D values
    with ``dense_dim`` columns. When ``mat_shape`` is ``None`` the expected
    shape is derived from the largest row and column index.
    """
    # Skip invalid matrices: an index falls outside the requested shape.
    if mat_shape is not None and (
        max(row) >= mat_shape[0] or max(col) >= mat_shape[1]
    ):
        return
    val_shape = (len(row),)
    if dense_dim is not None:
        val_shape += (dense_dim,)
    val = torch.randn(val_shape)
    row = torch.tensor(row)
    col = torch.tensor(col)
    mat = create_from_coo(row, col, val, mat_shape)

    if mat_shape is None:
        mat_shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)

    # Device check added for parity with the CSR/CSC creation tests.
    assert mat.device == val.device
    assert mat.shape == mat_shape
    assert mat.nnz == row.numel()
    assert mat.dtype == val.dtype
    assert torch.allclose(mat.val, val)
    assert torch.allclose(mat.row, row)
    assert torch.allclose(mat.col, col)
# Unconditionally skipped. The accompanying commit note attributes this to
# PyTorch 1.9.0 not providing torch.sparse_csr_tensor.
@pytest.mark.skip(reason="no way of currently testing this")
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("indptr", [[0, 0, 1, 4], (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5)])
def test_create_from_csr(dense_dim, indptr, indices, mat_shape):
    """Check a matrix built from CSR index arrays.

    The expected row indices are rebuilt by repeating each row id by the
    difference of consecutive ``indptr`` entries; the expected column
    indices are ``indices`` itself.
    """
    nnz = len(indices)
    val = torch.randn((nnz,) if dense_dim is None else (nnz, dense_dim))
    rowptr = torch.tensor(indptr)
    col_idx = torch.tensor(indices)
    mat = create_from_csr(rowptr, col_idx, val, mat_shape)

    if mat_shape is None:
        # One row per indptr interval; columns span up to the largest index.
        expected_shape = (rowptr.numel() - 1, torch.max(col_idx).item() + 1)
    else:
        expected_shape = mat_shape

    assert mat.device == val.device
    assert mat.shape == expected_shape
    assert mat.nnz == col_idx.numel()
    assert mat.dtype == val.dtype
    assert torch.allclose(mat.val, val)

    per_row = torch.diff(rowptr)
    expected_row = torch.repeat_interleave(
        torch.arange(per_row.numel()), per_row
    )
    assert torch.allclose(mat.row, expected_row)
    assert torch.allclose(mat.col, col_idx)
# Unconditionally skipped. The accompanying commit note attributes this to
# PyTorch 1.9.0 not providing torch.sparse_csr_tensor.
@pytest.mark.skip(reason="no way of currently testing this")
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("indptr", [[0, 0, 1, 4], (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("mat_shape", [None, (5, 3)])
def test_create_from_csc(dense_dim, indptr, indices, mat_shape):
    """Check a matrix built from CSC index arrays.

    The expected row indices are ``indices`` itself; the expected column
    indices are rebuilt by repeating each column id by the difference of
    consecutive ``indptr`` entries.
    """
    nnz = len(indices)
    val = torch.randn((nnz,) if dense_dim is None else (nnz, dense_dim))
    colptr = torch.tensor(indptr)
    row_idx = torch.tensor(indices)
    mat = create_from_csc(colptr, row_idx, val, mat_shape)

    if mat_shape is None:
        # Rows span up to the largest index; one column per indptr interval.
        expected_shape = (torch.max(row_idx).item() + 1, colptr.numel() - 1)
    else:
        expected_shape = mat_shape

    assert mat.device == val.device
    assert mat.shape == expected_shape
    assert mat.nnz == row_idx.numel()
    assert mat.dtype == val.dtype
    assert torch.allclose(mat.val, val)
    assert torch.allclose(mat.row, row_idx)

    per_col = torch.diff(colptr)
    expected_col = torch.repeat_interleave(
        torch.arange(per_col.numel()), per_col
    )
    assert torch.allclose(mat.col, expected_col)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment