Unverified Commit cf035927 authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

Move mock version of dgl_sparse library to DGL main repo (#4524)



* init

* Add api doc for sparse library

* support op between matrices with different sparsity

* Fixed docstring

* addresses comments

* lint check

* change keyword format to fmt
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
parent 701c6fe7
.. _apibackend:
dgl.mock_sparse
=================================
`dgl_sparse` is a library for sparse operators that are commonly used in GNN models.
.. warning::
This is an experimental package. The sparse operators provided in this library do not guarantee the same performance as their message-passing api counterparts.
Sparse matrix class
-------------------------
.. currentmodule:: dgl.mock_sparse
.. autoclass:: SparseMatrix
:members:
\ No newline at end of file
"""dgl sparse class."""
from .sp_matrix import *
from .elementwise_op_sp import *
"""dgl elementwise operators for sparse matrix module."""
import torch
from .sp_matrix import SparseMatrix
__all__ = ['add', 'sub', 'mul', 'div', 'rdiv', 'power', 'rpower']
def add(A, B):
    """Elementwise addition.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    B : SparseMatrix
        Sparse matrix

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    RuntimeError
        If either operand is not a SparseMatrix.

    Examples
    --------
    Case 1: Add two matrices of same sparsity structure

    >>> rowA = torch.tensor([1, 0, 2, 7, 1])
    >>> colA = torch.tensor([0, 49, 2, 1, 7])
    >>> valA = torch.tensor([10, 20, 30, 40, 50])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(10, 50))
    >>> A + A
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([ 40,  20, 100,  60,  80]),
    shape=(10, 50), nnz=5)
    >>> w = torch.arange(1, len(rowA)+1)
    >>> A + A(w)
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([21, 12, 53, 34, 45]),
    shape=(10, 50), nnz=5)

    Case 2: Add two matrices of different sparsity structure

    >>> rowB = torch.tensor([1, 9, 2, 7, 1, 1])
    >>> colB = torch.tensor([0, 1, 2, 1, 7, 11])
    >>> valB = torch.tensor([1, 2, 3, 4, 5, 6])
    >>> B = SparseMatrix(rowB, colB, valB, shape=(10, 50))
    >>> A + B
    SparseMatrix(indices=tensor([[ 0,  1,  1,  1,  2,  7,  9],
            [49,  0,  7, 11,  2,  1,  1]]),
    values=tensor([20, 11, 55,  6, 33, 44,  2]),
    shape=(10, 50), nnz=7)
    """
    if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
        assert A.shape == B.shape, 'The shape of sparse matrix A {} and' \
            ' B {} are expected to match.'.format(A.shape, B.shape)
        # torch supports sparse+sparse addition directly; re-wrap the result
        # so callers get a SparseMatrix rather than a raw torch tensor.
        C = A.adj + B.adj
        return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape)
    raise RuntimeError('Elementwise addition between {} and {} is not ' \
                       'supported.'.format(type(A), type(B)))
def sub(A, B):
    """Elementwise subtraction of two sparse matrices.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    B : SparseMatrix
        Sparse matrix

    Returns
    -------
    SparseMatrix
        Sparse matrix holding ``A - B``

    Raises
    ------
    RuntimeError
        If either operand is not a SparseMatrix.

    Examples
    --------
    >>> rowA = torch.tensor([1, 0, 2, 7, 1])
    >>> colA = torch.tensor([0, 49, 2, 1, 7])
    >>> valA = torch.tensor([10, 20, 30, 40, 50])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(10, 50))
    >>> rowB = torch.tensor([1, 9, 2, 7, 1, 1])
    >>> colB = torch.tensor([0, 1, 2, 1, 7, 11])
    >>> valB = torch.tensor([1, 2, 3, 4, 5, 6])
    >>> B = SparseMatrix(rowB, colB, valB, shape=(10, 50))
    >>> A - B
    SparseMatrix(indices=tensor([[ 0,  1,  1,  1,  2,  7,  9],
            [49,  0,  7, 11,  2,  1,  1]]),
    values=tensor([20,  9, 45, -6, 27, 36, -2]),
    shape=(10, 50), nnz=7)
    """
    # Guard clause: only sparse - sparse is defined.
    if not (isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix)):
        raise RuntimeError('Elementwise subtraction between {} and {} is not ' \
                           'supported.'.format(type(A), type(B)))
    assert A.shape == B.shape, 'The shape of sparse matrix A {} and' \
        ' B {} are expected to match.'.format(A.shape, B.shape)
    # Delegate to torch's sparse subtraction, then re-wrap the result.
    diff = A.adj - B.adj
    row_idx, col_idx = diff.indices()
    return SparseMatrix(row_idx, col_idx, diff.values(), diff.shape)
def mul(A, B):
    """Elementwise multiplication.

    Parameters
    ----------
    A : SparseMatrix or scalar
        Sparse matrix or scalar value
    B : SparseMatrix or scalar
        Sparse matrix or scalar value.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case 1: Elementwise multiplication between two sparse matrices

    >>> rowA = torch.tensor([1, 0, 2, 7, 1])
    >>> colA = torch.tensor([0, 49, 2, 1, 7])
    >>> valA = torch.tensor([10, 20, 30, 40, 50])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(10, 50))
    >>> rowB = torch.tensor([1, 9, 2, 7, 1, 1])
    >>> colB = torch.tensor([0, 1, 2, 1, 7, 11])
    >>> valB = torch.tensor([1, 2, 3, 4, 5, 6])
    >>> B = SparseMatrix(rowB, colB, valB, shape=(10, 50))
    >>> A * B
    SparseMatrix(indices=tensor([[1, 1, 2, 7],
            [0, 7, 2, 1]]),
    values=tensor([ 10, 250,  90, 160]),
    shape=(10, 50), nnz=4)

    Case 2: Elementwise multiplication between sparse matrix and scalar

    >>> v_scalar = 2.5
    >>> A * v_scalar
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([ 50.,  25., 125.,  75., 100.]),
    shape=(10, 50), nnz=5)
    >>> v_scalar * A
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([ 50.,  25., 125.,  75., 100.]),
    shape=(10, 50), nnz=5)
    """
    if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
        assert A.shape == B.shape, 'The shape of sparse matrix A {} and' \
            ' B {} are expected to match.'.format(A.shape, B.shape)
    # Unwrap any SparseMatrix operand to its torch tensor; torch handles both
    # sparse*sparse (intersection of patterns) and sparse*scalar.
    A = A.adj if isinstance(A, SparseMatrix) else A
    B = B.adj if isinstance(B, SparseMatrix) else B
    C = A * B
    return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape)
def div(A, B):
    """Elementwise division.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    B : SparseMatrix or scalar
        Sparse matrix or scalar value.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    ValueError
        If both operands are sparse matrices with different sparsity patterns.

    Examples
    --------
    Case 1: Elementwise division between two matrices of same sparsity (matrices
    with different sparsity is not supported)

    >>> rowA = torch.tensor([1, 0, 2, 7, 1])
    >>> colA = torch.tensor([0, 49, 2, 1, 7])
    >>> valA = torch.tensor([10, 20, 30, 40, 50])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(10, 50))
    >>> w = torch.arange(1, len(rowA)+1)
    >>> A/A(w)
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([20.0000,  5.0000, 16.6667,  7.5000,  8.0000]),
    shape=(10, 50), nnz=5)

    Case 2: Elementwise division between sparse matrix and scalar

    >>> v_scalar = 2.5
    >>> A / v_scalar
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([ 8.,  4., 20., 12., 16.]),
    shape=(10, 50), nnz=5)
    """
    if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
        # Same sparsity structure: divide the aligned value tensors directly,
        # since torch has no sparse/sparse division.
        if torch.equal(A.indices("COO"), B.indices("COO")):
            return SparseMatrix(A.row, A.col, A.val / B.val, A.shape)
        raise ValueError('Division between matrices of different sparsity is not supported')
    # Scalar divisor: delegate to torch's sparse/scalar division.
    C = A.adj/B
    return SparseMatrix(C.indices()[0], C.indices()[1], C.values(), C.shape)
def rdiv(A, B):
    """Elementwise division with a scalar numerator (reflected operand order).

    Dividing a scalar by a sparse matrix is undefined here, so this hook
    always raises.

    Parameters
    ----------
    A : scalar
        scalar value
    B : SparseMatrix
        Sparse matrix

    Raises
    ------
    RuntimeError
        Always.
    """
    raise RuntimeError(
        'Elementwise division between {} and {} is not ' \
        'supported.'.format(type(A), type(B)))
def power(A, B):
    """Elementwise power operation.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    B : scalar
        scalar value.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    RuntimeError
        If the exponent is a sparse matrix.

    Examples
    --------
    >>> rowA = torch.tensor([1, 0, 2, 7, 1])
    >>> colA = torch.tensor([0, 49, 2, 1, 7])
    >>> valA = torch.tensor([10, 20, 30, 40, 50])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(10, 50))
    >>> pow(A, 2.5)
    SparseMatrix(indices=tensor([[ 0,  1,  1,  2,  7],
            [49,  0,  7,  2,  1]]),
    values=tensor([ 1788.8544,   316.2278, 17677.6699,  4929.5029, 10119.2881]),
    shape=(10, 50), nnz=5)
    """
    if isinstance(B, SparseMatrix):
        raise RuntimeError('Power operation between two sparse matrices is not supported')
    # Pass A.shape explicitly: without it the result's shape would be
    # re-inferred from the indices and could shrink when trailing rows or
    # columns hold no nonzero entries.
    return SparseMatrix(A.row, A.col, torch.pow(A.val, B), A.shape)
def rpower(A, B):
    """Elementwise power with a scalar base (reflected operand order).

    Raising a scalar to a sparse-matrix power is undefined here, so this
    hook always raises.

    Parameters
    ----------
    A : scalar
        scalar value.
    B : SparseMatrix
        Sparse matrix.

    Raises
    ------
    RuntimeError
        Always.
    """
    raise RuntimeError(
        'Power operation between {} and {} is not ' \
        'supported.'.format(type(A), type(B)))
# Wire the elementwise helpers into Python's operator protocol so that
# `A + B`, `A * 2`, `2 * A`, `A / 2`, `A ** 2`, etc. work on SparseMatrix.
# NOTE(review): the reflected hooks __radd__/__rsub__ reuse the forward
# functions, so the operand order Python passes (self, other) is forwarded
# as-is; since add/sub require both operands to be SparseMatrix, the
# scalar-on-the-left paths simply raise RuntimeError -- confirm intent.
SparseMatrix.__add__ = add
SparseMatrix.__radd__ = add
SparseMatrix.__sub__ = sub
SparseMatrix.__rsub__ = sub
SparseMatrix.__mul__ = mul
SparseMatrix.__rmul__ = mul
SparseMatrix.__truediv__ = div
# scalar / sparse is explicitly unsupported (rdiv always raises).
SparseMatrix.__rtruediv__ = rdiv
SparseMatrix.__pow__ = power
# scalar ** sparse is explicitly unsupported (rpower always raises).
SparseMatrix.__rpow__ = rpower
"""dgl sparse matrix module."""
from typing import Optional, Tuple
import torch
__all__ = ['SparseMatrix', 'create_from_coo', 'create_from_csr', 'create_from_csc']
class SparseMatrix:
    r'''Class for sparse matrix.

    The matrix is backed by a coalesced ``torch.sparse_coo_tensor`` stored in
    ``self.adj``; all properties read from that tensor.

    Parameters
    ----------
    row : tensor
        The row indices of shape nnz.
    col : tensor
        The column indices of shape nnz.
    val : tensor, optional
        The values of shape (nnz, *). If None, it will be a tensor of shape (nnz)
        filled by 1.
    shape : tuple[int, int], optional
        Shape or size of the sparse matrix. If not provided the shape will be
        inferred from the row and column indices.

    Examples
    --------
    Case1: Sparse matrix with row indices, col indices and values (scalar).

    >>> src = torch.tensor([1, 1, 2])
    >>> dst = torch.tensor([2, 4, 3])
    >>> val = torch.tensor([1, 1, 1])
    >>> A = SparseMatrix(src, dst, val)
    >>> A.shape
    (3,5)
    >>> A.row
    tensor([1, 1, 2])
    >>> A.val
    tensor([1., 1., 1.])
    >>> A.nnz
    3

    Case2: Sparse matrix with row indices, col indices and values (vector).

    >>> ...
    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3]])
    >>> A = SparseMatrix(src, dst, val)
    >>> A.val
    tensor([[1, 1],
            [2, 2],
            [3, 3]])
    '''
    def __init__(self,
                 row: torch.Tensor,
                 col: torch.Tensor,
                 val: Optional[torch.Tensor] = None,
                 shape : Optional[Tuple[int, int]] = None
                 ):
        # The raw constructor arguments are stashed, but the authoritative
        # storage is the coalesced COO tensor built below; the row/col/val
        # properties read from self.adj, not from these fields.
        self._row = row
        self._col = col
        if val is None:
            # Default: scalar value 1.0 for every nonzero entry.
            val = torch.ones(row.shape[0])
        self._val = val
        # Stack row/col into the (2, nnz) index tensor torch expects.
        i = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)
        # coalesce() sorts indices (row-major) and sums duplicate entries, so
        # the stored order may differ from the input order.
        if shape is not None:
            self.adj = torch.sparse_coo_tensor(i, val, shape).coalesce()
        else:
            self.adj = torch.sparse_coo_tensor(i, val).coalesce()
    def __repr__(self):
        # Mirrors torch's sparse-tensor repr but with the SparseMatrix name.
        return f'SparseMatrix(indices={self.indices("COO")}, \nvalues={self.val}, \
\nshape={self.shape}, nnz={self.nnz})'
    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the sparse matrix.

        Returns
        -------
        tuple[int]
            The shape of the matrix
        """
        # Only the first two dims: vector-valued matrices carry an extra
        # trailing dim on self.adj that is not part of the matrix shape.
        return (self.adj.shape[0], self.adj.shape[1])
    @property
    def nnz(self) -> int:
        """The number of nonzero elements of the sparse matrix.

        Returns
        -------
        int
            The number of nonzero elements of the matrix
        """
        return self.adj._nnz()
    @property
    def dtype(self) -> torch.dtype:
        """Data type of the values of the sparse matrix.

        Returns
        -------
        torch.dtype
            Data type of the values of the matrix
        """
        return self.adj.dtype
    @property
    def device(self) -> torch.device:
        """Device of the sparse matrix.

        Returns
        -------
        torch.device
            Device of the matrix
        """
        return self.adj.device
    @property
    def row(self) -> torch.tensor:
        """Get the row indices of the nonzero elements.

        Returns
        -------
        tensor
            Row indices of the nonzero elements
        """
        # Read from the coalesced tensor, so order follows coalesce(), not
        # the constructor arguments.
        return self.adj.indices()[0]
    @property
    def col(self) -> torch.tensor:
        """Get the column indices of the nonzero elements.

        Returns
        -------
        tensor
            Column indices of the nonzero elements
        """
        return self.adj.indices()[1]
    @property
    def val(self) -> torch.tensor:
        """Get the values of the nonzero elements.

        Returns
        -------
        tensor
            Values of the nonzero elements
        """
        return self.adj.values()
    @val.setter
    def val(self, x) -> torch.tensor:
        """Set the values of the nonzero elements."""
        assert len(x) == self.nnz
        self._val = x
        # Rebuild the backing tensor in place; vector-shaped values extend the
        # dense shape with a trailing feature dimension.
        if len(x.shape) == 1:
            shape = self.shape
        else:
            shape = self.shape + (x.shape[-1],)
        self.adj = torch.sparse_coo_tensor(self.adj.indices(), x, shape).coalesce()
    def __call__(self, x):
        """Create a new sparse matrix with the same sparsity as self but different values.

        Parameters
        ----------
        x : tensor
            Values of the new sparse matrix

        Returns
        -------
        Class object
            A new sparse matrix object of the SparseMatrix class
        """
        assert len(x) == self.nnz
        # self.row/self.col come from the coalesced tensor, so x must be
        # aligned with the coalesced (sorted) order of the nonzeros.
        return SparseMatrix(self.row, self.col, x, shape=self.shape)
    def indices(self, fmt, return_shuffle=False) -> Tuple[torch.tensor, ...]:
        """Get the indices of the nonzero elements.

        Parameters
        ----------
        fmt : str
            Sparse matrix storage format. Can be COO or CSR or CSC.
        return_shuffle: bool
            If true, return an extra array of the nonzero value IDs

        Returns
        -------
        tensor
            Indices of the nonzero elements

        Raises
        ------
        NotImplementedError
            For any format other than plain 'COO'.
        """
        # Mock library: only the COO path is implemented.
        if fmt == 'COO' and not return_shuffle:
            return self.adj.indices()
        else:
            raise NotImplementedError
    def coo(self) -> Tuple[torch.tensor, ...]:
        """Get the coordinate (COO) representation of the sparse matrix.

        Returns
        -------
        tensor
            A tensor containing indices and value tensors.
        """
        # NOTE(review): mock placeholder -- returns self unconverted.
        return self
    def csr(self) -> Tuple[torch.tensor, ...]:
        """Get the CSR (Compressed Sparse Row) representation of the sparse matrix.

        Returns
        -------
        tensor
            A tensor containing compressed row pointers, column indices and value tensors.
        """
        # NOTE(review): mock placeholder -- returns self unconverted.
        return self
    def csc(self) -> Tuple[torch.tensor, ...]:
        """Get the CSC (Compressed Sparse Column) representation of the sparse matrix.

        Returns
        -------
        tensor
            A tensor containing compressed column pointers, row indices and value tensors.
        """
        # NOTE(review): mock placeholder -- returns self unconverted.
        return self
    def dense(self) -> torch.tensor:
        """Get the dense representation of the sparse matrix.

        Returns
        -------
        tensor
            Dense representation of the sparse matrix.
        """
        return self.adj.to_dense()
def create_from_coo(row: torch.Tensor,
                    col: torch.Tensor,
                    val: Optional[torch.Tensor] = None) -> SparseMatrix:
    """Create a sparse matrix from row and column coordinates.

    Parameters
    ----------
    row : tensor
        The row indices of shape nnz.
    col : tensor
        The column indices of shape nnz.
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
        shape (nnz) filled by 1.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix with row and column indices without values.

    >>> src = torch.tensor([1, 1, 2])
    >>> dst = torch.tensor([2, 4, 3])
    >>> A = create_from_coo(src, dst)
    >>> A
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 4, 3]]),
    values=tensor([1., 1., 1.]),
    shape=(3, 5), nnz=3)

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3]])
    >>> A = create_from_coo(src, dst, val)
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 4, 3]]),
    values=tensor([[1, 1],
            [2, 2],
            [3, 3]]),
    shape=(3, 5), nnz=3)
    """
    # COO is the native layout of SparseMatrix, so this is a thin wrapper.
    return SparseMatrix(row=row, col=col, val=val)
def create_from_csr(indptr: torch.Tensor,
                    indices: torch.Tensor,
                    val: Optional[torch.Tensor] = None) -> SparseMatrix:
    """Create a sparse matrix from CSR indices.

    For row i of the sparse matrix

    - the column indices of the nonzero entries are stored in ``indices[indptr[i]: indptr[i+1]]``
    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``

    Parameters
    ----------
    indptr : tensor
        Pointer to the column indices of shape N + 1, where N is the number of rows.
    indices : tensor
        The column indices of shape nnz.
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
        shape (nnz) filled by 1.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix without values

    [[0, 1, 0],
     [0, 0, 1],
     [1, 1, 1]]

    >>> indptr = torch.tensor([0, 1, 2, 5])
    >>> indices = torch.tensor([1, 2, 0, 1, 2])
    >>> A = create_from_csr(indptr, indices)
    >>> A.shape
    (3, 3)
    >>> A.row
    tensor([0, 1, 2, 2, 2])
    >>> A.val
    tensor([1., 1., 1., 1., 1.])
    >>> A.nnz
    5

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    >>> A = create_from_csr(indptr, indices, val)
    >>> A.val
    tensor([[1, 1],
            [2, 2],
            [3, 3],
            [4, 4],
            [5, 5]])
    """
    # Build a throwaway CSR tensor with placeholder values purely to recover
    # the COO coordinates; `val` is attached afterwards in CSR order, which
    # matches the row-major order produced by coalesce().
    nnz = indices.shape[0]
    csr_tensor = torch.sparse_csr_tensor(indptr, indices, torch.ones(nnz))
    coo_tensor = csr_tensor.to_sparse_coo().coalesce()
    row, col = coo_tensor.indices()
    return SparseMatrix(row, col, val)
def create_from_csc(indptr: torch.Tensor,
                    indices: torch.Tensor,
                    val: Optional[torch.Tensor] = None) -> SparseMatrix:
    """Create a sparse matrix from CSC indices.

    For column i of the sparse matrix

    - the row indices of the nonzero entries are stored in ``indices[indptr[i]: indptr[i+1]]``
    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``

    Parameters
    ----------
    indptr : tensor
        Pointer to the row indices of shape N + 1, where N is the number of columns.
    indices : tensor
        The row indices of shape nnz.
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of
        shape (nnz) filled by 1.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix without values

    [[0, 1, 0],
     [0, 0, 1],
     [1, 1, 1]]

    >>> indptr = torch.tensor([0, 1, 3, 5])
    >>> indices = torch.tensor([2, 0, 2, 1, 2])
    >>> A = create_from_csc(indptr, indices)
    >>> A.shape
    (3, 3)
    >>> A.row
    tensor([0, 1, 2, 2, 2])
    >>> A.val
    tensor([1., 1., 1., 1., 1.])
    >>> A.nnz
    5

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    >>> A = create_from_csc(indptr, indices, val)
    >>> A.val
    tensor([[1, 1],
            [2, 2],
            [3, 3],
            [4, 4],
            [5, 5]])
    """
    # Interpret the CSC triple as the CSR layout of the TRANSPOSE, recover
    # COO coordinates, then swap the axes back: the CSR "rows" are columns.
    nnz = indices.shape[0]
    csr_of_transpose = torch.sparse_csr_tensor(indptr, indices, torch.ones(nnz))
    coo_tensor = csr_of_transpose.to_sparse_coo().coalesce()
    col, row = coo_tensor.indices()
    return SparseMatrix(row, col, val)
import numpy as np
import pytest
import dgl
import dgl.backend as F
import torch
import numpy
import operator
from dgl.mock_sparse import SparseMatrix
# Shared pytest parametrization decorators: run each test over both index
# dtypes and both float dtypes exposed by the DGL backend.
parametrize_idtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
parametrize_dtype = pytest.mark.parametrize('dtype', [F.float32, F.float64])
def all_close_sparse(A, B):
    """Assert two torch sparse COO tensors agree in shape, indices and values."""
    assert A.shape == B.shape
    assert torch.allclose(A.indices(), B.indices())
    assert torch.allclose(A.values(), B.values())
@parametrize_idtype
@parametrize_dtype
@pytest.mark.parametrize('op', [operator.add, operator.sub, operator.mul, operator.truediv])
def test_sparse_op_sparse(idtype, dtype, op):
    """Elementwise sparse-op-sparse matches torch's native sparse result."""
    rowA = torch.tensor([1, 0, 2, 7, 1])
    colA = torch.tensor([0, 49, 2, 1, 7])
    # A and A1 share one sparsity pattern; B uses a different one.
    A = SparseMatrix(rowA, colA, torch.rand(len(rowA)), shape=(10, 50))
    A1 = SparseMatrix(rowA, colA, torch.rand(len(rowA)), shape=(10, 50))
    rowB = torch.tensor([1, 9, 2, 7, 1, 1, 0])
    colB = torch.tensor([0, 1, 2, 1, 7, 11, 15])
    B = SparseMatrix(rowB, colB, torch.rand(len(rowB)), shape=(10, 50))
    if op is operator.truediv:
        # sparse div is not supported in PyTorch, so compare values directly
        # (only defined for matching sparsity).
        assert np.allclose(op(A, A1).val, op(A.val, A1.val), rtol=1e-4, atol=1e-4)
    else:
        all_close_sparse(op(A.adj, A1.adj), op(A, A1).adj)
        all_close_sparse(op(A.adj, B.adj), op(A, B).adj)
@parametrize_idtype
@parametrize_dtype
@pytest.mark.parametrize('v_scalar', [2, 2.5])
def test_sparse_op_scalar(idtype, dtype, v_scalar):
    """Sparse-op-scalar matches the same op applied to the raw torch tensor."""
    row = torch.randint(1, 500, (100,))
    col = torch.randint(1, 500, (100,))
    mat = SparseMatrix(row, col, torch.rand(100))
    all_close_sparse(mat.adj * v_scalar, (mat * v_scalar).adj)
    all_close_sparse(mat.adj / v_scalar, (mat / v_scalar).adj)
    all_close_sparse(pow(mat.adj, v_scalar), pow(mat, v_scalar).adj)
@parametrize_idtype
@parametrize_dtype
@pytest.mark.parametrize('v_scalar', [2, 2.5])
def test_scalar_op_sparse(idtype, dtype, v_scalar):
    """Reflected scalar * sparse matches torch's scalar * sparse-tensor result."""
    row = torch.randint(1, 500, (100,))
    col = torch.randint(1, 500, (100,))
    mat = SparseMatrix(row, col, torch.rand(100))
    all_close_sparse(v_scalar * mat.adj, (v_scalar * mat).adj)
def test_expose_op():
    """The elementwise helpers are reachable from the dgl.mock_sparse namespace."""
    row = torch.tensor([1, 0, 2, 7, 1])
    col = torch.tensor([0, 49, 2, 1, 7])
    mat = dgl.mock_sparse.SparseMatrix(row, col, shape=(10, 50))
    # Smoke-test each exposed function; order matches the original test.
    dgl.mock_sparse.add(mat, mat)
    dgl.mock_sparse.sub(mat, mat)
    dgl.mock_sparse.mul(mat, mat)
    dgl.mock_sparse.div(mat, mat)
if __name__ == '__main__':
    # Bug fix: the test functions take parametrized arguments (normally
    # injected by pytest); calling them with no arguments raised TypeError.
    # Supply concrete values when running this file as a script.
    for op in (operator.add, operator.sub, operator.mul, operator.truediv):
        test_sparse_op_sparse(F.int64, F.float32, op)
    for v_scalar in (2, 2.5):
        test_sparse_op_scalar(F.int64, F.float32, v_scalar)
        test_scalar_op_sparse(F.int64, F.float32, v_scalar)
    test_expose_op()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment