"docs/source/api/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "baa16231f323f0d785c2ab8007e0b2e49499dd35"
Unverified Commit 11e61905, authored by Mufei Li and committed by GitHub

[Sparse] Add Two Sparse Matrices with Different Sparsity, DiagMatrix.dense, Add SparseMatrix and DiagMatrix (#5044)

* Update

* undo wrong edit

* lint

* fix import

* Update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent f1ee3e31
@@ -62,6 +62,16 @@ std::shared_ptr<CSR> CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr);
 /** @brief Convert a CSR in the sparse library to an old DGL CSR matrix. */
 aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr);
 
+/**
+ * @brief Convert a COO and its nonzero values to a Torch COO matrix.
+ * @param coo The COO format in the sparse library
+ * @param value Values of the sparse matrix
+ *
+ * @return Torch Sparse Tensor in COO format
+ */
+torch::Tensor COOToTorchCOO(
+    const std::shared_ptr<COO>& coo, torch::Tensor value);
+
 /** @brief Convert a CSR format to COO format. */
 std::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr);
...
@@ -21,16 +21,14 @@ namespace sparse {
 c10::intrusive_ptr<SparseMatrix> SpSpAdd(
     const c10::intrusive_ptr<SparseMatrix>& A,
     const c10::intrusive_ptr<SparseMatrix>& B) {
-  auto fmt = FindAnyExistingFormat(A, B);
-  auto value = A->value() + B->value();
   ElementwiseOpSanityCheck(A, B);
-  if (fmt == SparseFormat::kCOO) {
-    return SparseMatrix::FromCOO(A->COOPtr(), value, A->shape());
-  } else if (fmt == SparseFormat::kCSR) {
-    return SparseMatrix::FromCSR(A->CSRPtr(), value, A->shape());
-  } else {
-    return SparseMatrix::FromCSC(A->CSCPtr(), value, A->shape());
-  }
+  auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
+  auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
+  auto sum = (torch_A + torch_B).coalesce();
+  auto indices = sum.indices();
+  auto row = indices[0];
+  auto col = indices[1];
+  return CreateFromCOO(row, col, sum.values(), A->shape());
 }
 
 }  // namespace sparse
...
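The old SpSpAdd reused one operand's sparsity pattern and so required both matrices to share it; the new version supports different patterns by converting each operand to a torch COO tensor, adding, and coalescing duplicate coordinates. A minimal pure-PyTorch sketch of that strategy (the tensors below are illustrative, not taken from the PR):

import torch

# Two COO matrices with different sparsity patterns over the same 3 x 4 shape.
shape = (3, 4)
a = torch.sparse_coo_tensor(
    torch.tensor([[1, 0, 2], [0, 3, 2]]), torch.tensor([10.0, 20.0, 30.0]), shape
)
b = torch.sparse_coo_tensor(
    torch.tensor([[1, 0], [0, 2]]), torch.tensor([1.0, 2.0]), shape
)

# Sparse addition may leave duplicate coordinates; coalesce() merges them
# by summing their values, yielding the union of the two patterns.
s = (a + b).coalesce()
row, col = s.indices()[0], s.indices()[1]
print(row, col, s.values())  # values summed where the patterns overlap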
@@ -48,6 +48,19 @@ aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) {
       csr->num_rows, csr->num_cols, indptr, indices, data, csr->sorted);
 }
 
+torch::Tensor COOToTorchCOO(
+    const std::shared_ptr<COO>& coo, torch::Tensor value) {
+  std::vector<torch::Tensor> indices = {coo->row, coo->col};
+  if (value.ndimension() == 2) {
+    return torch::sparse_coo_tensor(
+        torch::stack(indices), value,
+        {coo->num_rows, coo->num_cols, value.size(1)});
+  } else {
+    return torch::sparse_coo_tensor(
+        torch::stack(indices), value, {coo->num_rows, coo->num_cols});
+  }
+}
+
 std::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr) {
   auto dgl_csr = CSRToOldDGLCSR(csr);
   auto dgl_coo = aten::CSRToCOO(dgl_csr, csr->value_indices.has_value());
...
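The 2-D branch covers vector-valued nonzeros: each coordinate carries a value vector, so the resulting torch COO tensor is a "hybrid" sparse tensor whose size gains a trailing dense dimension. An illustrative sketch of the same construction from Python:

import torch

row = torch.tensor([0, 2])
col = torch.tensor([1, 2])
val = torch.randn(2, 4)  # one length-4 value vector per nonzero

# Two sparse dims (3, 3) plus a trailing dense dim of size 4, mirroring the
# {num_rows, num_cols, value.size(1)} branch above.
t = torch.sparse_coo_tensor(torch.stack([row, col]), val, (3, 3, 4))
print(t.to_dense().shape)  # torch.Size([3, 3, 4])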
@@ -7,6 +7,8 @@ import torch
 from .._ffi import libinfo
 from .diag_matrix import *
 from .elementwise_op import *
+from .elementwise_op_diag import *
+from .elementwise_op_sp import *
 from .sparse_matrix import *
 from .unary_op_diag import *
 from .unary_op_sp import *
...
@@ -121,6 +121,22 @@ class DiagMatrix:
         row = col = torch.arange(len(self.val)).to(self.device)
         return create_from_coo(row=row, col=col, val=self.val, shape=self.shape)
 
+    def dense(self) -> torch.Tensor:
+        """Return a dense representation of the matrix.
+
+        Returns
+        -------
+        torch.Tensor
+            Dense representation of the diagonal matrix.
+        """
+        val = self.val
+        device = self.device
+        shape = self.shape + val.shape[1:]
+        mat = torch.zeros(shape, device=device, dtype=self.dtype)
+        row = col = torch.arange(len(val)).to(device)
+        mat[row, col] = val
+        return mat
+
     def t(self):
         """Alias of :meth:`transpose()`"""
         return self.transpose()
...
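A usage sketch for the new DiagMatrix.dense(), assuming the dgl.mock_sparse2 module from this PR is importable; note that dense() preserves the value dtype (int64 here):

import torch
from dgl.mock_sparse2 import diag

D = diag(torch.arange(1, 4))
print(D.dense())
# tensor([[1, 0, 0],
#         [0, 2, 0],
#         [0, 0, 3]])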
@@ -2,20 +2,43 @@
 from typing import Union
 
 from .diag_matrix import DiagMatrix
-from .elementwise_op_diag import diag_add
-from .elementwise_op_sp import sp_add
 from .sparse_matrix import SparseMatrix
 
 __all__ = ["add", "power"]
 
 def add(
-    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
-) -> Union[SparseMatrix, DiagMatrix]:
-    """Elementwise addition"""
-    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
-        return diag_add(A, B)
-    return sp_add(A, B)
+    A: Union[DiagMatrix, SparseMatrix], B: Union[DiagMatrix, SparseMatrix]
+) -> Union[DiagMatrix, SparseMatrix]:
+    """Elementwise addition
+
+    Parameters
+    ----------
+    A : DiagMatrix or SparseMatrix
+        Diagonal matrix or sparse matrix
+    B : DiagMatrix or SparseMatrix
+        Diagonal matrix or sparse matrix
+
+    Returns
+    -------
+    DiagMatrix or SparseMatrix
+        Diagonal matrix if both :attr:`A` and :attr:`B` are diagonal matrices,
+        sparse matrix otherwise
+
+    Examples
+    --------
+    >>> row = torch.tensor([1, 0, 2])
+    >>> col = torch.tensor([0, 1, 2])
+    >>> val = torch.tensor([10, 20, 30])
+    >>> A = create_from_coo(row, col, val)
+    >>> B = diag(torch.arange(1, 4))
+    >>> A + B
+    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
+                                 [0, 1, 0, 1, 2]]),
+                 values=tensor([ 1, 20, 10, 2, 33]),
+                 shape=(3, 3), nnz=5)
+    """
+    return A + B
 
 def power(
...
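Since add now simply delegates to A + B, the result type follows the operands: two diagonal matrices stay diagonal, anything involving a SparseMatrix comes back sparse. A sketch, assuming the dgl.mock_sparse2 operators behave as the docstrings in this PR describe:

import torch
from dgl.mock_sparse2 import add, create_from_coo, diag

D1 = diag(torch.arange(1, 4))
D2 = diag(torch.arange(4, 7))
print(type(add(D1, D2)).__name__)  # DiagMatrix: diag + diag stays diagonal

A = create_from_coo(
    torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([1, 2]), shape=(3, 3)
)
print(type(add(D1, A)).__name__)   # SparseMatrix: mixed addition goes sparse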
"""DGL elementwise operators for diagonal matrix module.""" """DGL elementwise operators for diagonal matrix module."""
from typing import Union from typing import Union
from .diag_matrix import diag, DiagMatrix from .diag_matrix import DiagMatrix, diag
from .sparse_matrix import SparseMatrix
__all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"] __all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: def diag_add(
D1: DiagMatrix, D2: Union[DiagMatrix, SparseMatrix]
) -> Union[DiagMatrix, SparseMatrix]:
"""Elementwise addition """Elementwise addition
Parameters Parameters
---------- ----------
D1 : DiagMatrix D1 : DiagMatrix
Diagonal matrix Diagonal matrix
D2 : DiagMatrix D2 : DiagMatrix or SparseMatrix
Diagonal matrix Diagonal matrix or sparse matrix
Returns Returns
------- -------
DiagMatrix DiagMatrix or SparseMatrix
Diagonal matrix Diagonal matrix or sparse matrix, same as D2
Examples Examples
-------- --------
...@@ -35,6 +38,13 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: ...@@ -35,6 +38,13 @@ def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
f"{D1.shape} and D2 {D2.shape} must match." f"{D1.shape} and D2 {D2.shape} must match."
) )
return diag(D1.val + D2.val, D1.shape) return diag(D1.val + D2.val, D1.shape)
elif isinstance(D1, DiagMatrix) and isinstance(D2, SparseMatrix):
assert D1.shape == D2.shape, (
"The shape of diagonal matrix D1 "
f"{D1.shape} and sparse matrix D2 {D2.shape} must match."
)
D1 = D1.as_sparse()
return D1 + D2
raise RuntimeError( raise RuntimeError(
"Elementwise addition between " "Elementwise addition between "
f"{type(D1)} and {type(D2)} is not supported." f"{type(D1)} and {type(D2)} is not supported."
......
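The mixed diag + sparse path converts the diagonal operand with as_sparse() and reuses the sparse addition kernel, so every spelling should agree with the dense computation. A sketch along the lines of the PR's own test_add_sparse_diag:

import torch
from dgl.mock_sparse2 import add, create_from_coo, diag

row = torch.tensor([1, 0, 2])
col = torch.tensor([0, 3, 2])
A = create_from_coo(row, col, torch.randn(3), shape=(3, 4))
D = diag(torch.randn(3), shape=(3, 4))

# Operator and functional forms, in both orders, against the dense reference.
dense_sum = A.dense() + D.dense()
assert torch.allclose((A + D).dense(), dense_sum)
assert torch.allclose((D + A).dense(), dense_sum)
assert torch.allclose(add(D, A).dense(), dense_sum)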
@@ -3,6 +3,7 @@ from typing import Union
 
 import torch
 
+from .diag_matrix import DiagMatrix
 from .sparse_matrix import SparseMatrix, val_like
 
 __all__ = ["sp_add", "sp_power"]
@@ -15,15 +16,15 @@ def spsp_add(A, B):
     )
 
-def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
-    """Elementwise addition.
+def sp_add(A: SparseMatrix, B: Union[DiagMatrix, SparseMatrix]) -> SparseMatrix:
+    """Elementwise addition
 
     Parameters
     ----------
     A : SparseMatrix
         Sparse matrix
-    B : SparseMatrix
-        Sparse matrix
+    B : DiagMatrix or SparseMatrix
+        Diagonal matrix or sparse matrix
 
     Returns
     -------
@@ -43,6 +44,8 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
         values=tensor([40, 20, 60]),
         shape=(3, 4), nnz=3)
     """
+    if isinstance(B, DiagMatrix):
+        B = B.as_sparse()
     if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
         return spsp_add(A, B)
     raise RuntimeError(
...
@@ -136,7 +136,7 @@ class SparseMatrix:
         row, col = self.coo()
         val = self.val
         shape = self.shape + val.shape[1:]
-        mat = torch.zeros(shape, device=self.device)
+        mat = torch.zeros(shape, device=self.device, dtype=self.dtype)
         mat[row, col] = val
         return mat
...
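A quick check of this one-line fix: SparseMatrix.dense() now allocates its zero buffer with the matrix's own dtype rather than torch.zeros' float32 default. Illustrative, assuming dgl.mock_sparse2:

import torch
from dgl.mock_sparse2 import create_from_coo

A = create_from_coo(
    torch.tensor([0, 1]), torch.tensor([1, 0]), torch.tensor([10, 20])
)
assert A.dense().dtype == torch.int64  # was silently promoted to float32 before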
import sys

import backend as F
import pytest
import torch

from dgl.mock_sparse2 import (add, create_from_coo, create_from_csc,
                              create_from_csr, diag)

# TODO(#4818): Skipping tests on win.
if not sys.platform.startswith("linux"):
    pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_coo(val_shape):
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = create_from_coo(row, col, val)

    row = torch.tensor([1, 0]).to(ctx)
    col = torch.tensor([0, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    B = create_from_coo(row, col, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()
    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csr(val_shape):
    ctx = F.ctx()
    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)
    indices = torch.tensor([3, 0, 2]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = create_from_csr(indptr, indices, val)

    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)
    indices = torch.tensor([2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = create_from_csr(indptr, indices, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()
    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csc(val_shape):
    ctx = F.ctx()
    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)
    indices = torch.tensor([1, 2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = create_from_csc(indptr, indices, val)

    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)
    indices = torch.tensor([1, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = create_from_csc(indptr, indices, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()
    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_diag(val_shape):
    ctx = F.ctx()
    shape = (3, 4)
    val_shape = (shape[0],) + val_shape
    D1 = diag(torch.randn(val_shape).to(ctx), shape=shape)
    D2 = diag(torch.randn(val_shape).to(ctx), shape=shape)

    sum1 = (D1 + D2).dense()
    sum2 = add(D1, D2).dense()
    dense_sum = D1.dense() + D2.dense()
    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = create_from_coo(row, col, val)

    shape = (3, 4)
    val_shape = (shape[0],) + val_shape
    D = diag(torch.randn(val_shape).to(ctx), shape=shape)

    sum1 = (A + D).dense()
    sum2 = (D + A).dense()
    sum3 = add(A, D).dense()
    sum4 = add(D, A).dense()
    dense_sum = A.dense() + D.dense()
    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)
    assert torch.allclose(dense_sum, sum3)
    assert torch.allclose(dense_sum, sum4)
@@ -19,7 +19,7 @@ def all_close_sparse(A, B):
 
 @pytest.mark.parametrize(
-    "op", [operator.add, operator.sub, operator.mul, operator.truediv]
+    "op", [operator.sub, operator.mul, operator.truediv]
 )
 def test_diag_op_diag(op):
     ctx = F.ctx()
...
@@ -21,22 +21,6 @@ def all_close_sparse(A, row, col, val, shape):
     assert A.shape == shape
 
-@pytest.mark.parametrize("op", [operator.add])
-def test_sparse_op_sparse(op):
-    ctx = F.ctx()
-    rowA = torch.tensor([1, 0, 2, 7, 1]).to(ctx)
-    colA = torch.tensor([0, 49, 2, 1, 7]).to(ctx)
-    valA = torch.rand(len(rowA)).to(ctx)
-    A = create_from_coo(rowA, colA, valA, shape=(10, 50))
-    w = torch.rand(len(rowA)).to(ctx)
-    A1 = create_from_coo(rowA, colA, w, shape=(10, 50))
-
-    def _test():
-        all_close_sparse(op(A, A1), rowA, colA, valA + w, (10, 50))
-
-    _test()
-
 @pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
 def test_pow(val_shape):
     # A ** v
...