Unverified commit 0159c3c1 authored by czkkkkkk, committed by GitHub

[Sparse] Add SpSpMM for sparse-sparse and sparse-diag matrix multiplication. (#5050)



* [Sparse] Add SpSpMM

* Update matmul interface

* address comments

* fix test utils to generate only coalesced matrices

* fix linter

* fix ut

* fix

* rm print
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent f0ce2bea
project(dgl_sparse C CXX)
# Find PyTorch cmake files and PyTorch versions with the python interpreter $TORCH_PYTHON_INTERPS
# ("python3" or "python" if empty)
if(NOT TORCH_PYTHON_INTERPS)
......
/**
* Copyright (c) 2022 by Contributors
* @file sparse/spspmm.h
* @brief DGL C++ SpSpMM operator.
*/
#ifndef SPARSE_SPSPMM_H_
#define SPARSE_SPSPMM_H_
#include <sparse/sparse_matrix.h>
#include <torch/script.h>
namespace dgl {
namespace sparse {
/**
* @brief Perform a sparse-sparse matrix multiplication on matrices with
* possibly different sparsities. The two sparse matrices must have
* 1-D values. If the first sparse matrix has shape (n, m), the second
* sparse matrix must have shape (m, k), and the returned sparse matrix has
* shape (n, k).
*
* This function supports autograd for both sparse matrices but does
* not support higher-order gradients.
*
* @param lhs_mat The first sparse matrix of shape (n, m).
* @param rhs_mat The second sparse matrix of shape (m, k).
*
* @return Sparse matrix of shape (n, k).
*/
c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
} // namespace sparse
} // namespace dgl
#endif // SPARSE_SPSPMM_H_
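A quick usage sketch of this operator through the Python wrapper added later in this commit (illustrative only; `create_from_coo`, `.dense()` and the `@` dispatch are used the same way in the tests below, with the `dgl.mock_sparse2` module path taken from those tests):

import torch
from dgl.mock_sparse2 import create_from_coo

# (2, 3) @ (3, 2) -> (2, 2); both operands carry 1-D values.
A = create_from_coo(torch.tensor([0, 1]), torch.tensor([1, 0]), torch.randn(2), (2, 3))
B = create_from_coo(torch.tensor([0, 2]), torch.tensor([0, 1]), torch.randn(2), (3, 2))
C = A @ B  # dispatches to SpSpMM
assert torch.allclose(C.dense(), A.dense() @ B.dense())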
...@@ -98,5 +98,35 @@ torch::Tensor SDDMMNoAutoGrad(
return ret;
}
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
bool lhs_transpose, bool rhs_transpose) {
aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;
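// When a transpose is requested, use the CSC representation instead: the CSC
// arrays of a matrix, read as a CSR, describe its transpose.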
if (!lhs_transpose) {
lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());
} else {
lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());
}
if (!rhs_transpose) {
rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());
} else {
rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());
}
auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);
auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);
const int64_t ret_row =
lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];
const int64_t ret_col =
rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];
std::vector<int64_t> ret_shape({ret_row, ret_col});
aten::CSRMatrix ret_dgl_csr;
runtime::NDArray ret_val;
std::tie(ret_dgl_csr, ret_val) =
aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);
return SparseMatrix::FromCSR(
CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);
}
} // namespace sparse
} // namespace dgl
...@@ -53,6 +53,28 @@ torch::Tensor SDDMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
torch::Tensor mat2_tr);
/**
* @brief Perform a sparse-sparse matrix multiplication on matrices with
* possibly different sparsities. The two sparse matrices must have
* 1-dimensional values. If the
* first sparse matrix has shape (n, m), the second sparse matrix must have
* shape (m, k), and the returned sparse matrix has shape (n, k).
*
* This function does not handle autograd.
*
* @param lhs_mat The first sparse matrix of shape (n, m).
* @param lhs_val Sparse value for the first sparse matrix.
* @param rhs_mat The second sparse matrix of shape (m, k).
* @param rhs_val Sparse value for the second sparse matrix.
* @param lhs_transpose Whether the first matrix is transposed.
* @param rhs_transpose Whether the second matrix is transposed.
*
* @return Sparse matrix of shape (n, k).
*/
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,
bool lhs_transpose, bool rhs_transpose);
} // namespace sparse
} // namespace dgl
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spmm.h>
#include <sparse/spspmm.h>
#include <torch/custom_class.h>
#include <torch/script.h>
...@@ -40,7 +41,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("sprod", &ReduceProd)
.def("val_like", &CreateValLike)
.def("spmm", &SpMM)
.def("sddmm", &SDDMM)
.def("spspmm", &SpSpMM);
}
} // namespace sparse
......
...@@ -95,6 +95,8 @@ c10::intrusive_ptr<SparseMatrix> SDDMM(
torch::Tensor mat2) {
if (mat1.dim() == 1) {
mat1 = mat1.view({mat1.size(0), 1});
}
if (mat2.dim() == 1) {
mat2 = mat2.view({1, mat2.size(0)});
}
_SDDMMSanityCheck(sparse_mat, mat1, mat2);
......
/**
* Copyright (c) 2022 by Contributors
* @file spspmm.cc
* @brief DGL C++ sparse SpSpMM operator implementation.
*/
#include <sparse/sddmm.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spspmm.h>
#include <torch/script.h>
#include "./matmul.h"
#include "./utils.h"
namespace dgl {
namespace sparse {
using namespace torch::autograd;
class SpSpMMAutoGrad : public Function<SpSpMMAutoGrad> {
public:
static variable_list forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val);
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
void _SpSpMMSanityCheck(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
const auto& lhs_shape = lhs_mat->shape();
const auto& rhs_shape = rhs_mat->shape();
CHECK_EQ(lhs_shape[1], rhs_shape[0])
<< "SpSpMM: the second dim of lhs_mat should be equal to the first dim "
"of the second matrix";
CHECK_EQ(lhs_mat->value().dim(), 1)
<< "SpSpMM: the value shape of lhs_mat should be 1-D";
CHECK_EQ(rhs_mat->value().dim(), 1)
<< "SpSpMM: the value shape of rhs_mat should be 1-D";
CHECK_EQ(lhs_mat->device(), rhs_mat->device())
<< "SpSpMM: lhs_mat and rhs_mat should on the same device";
CHECK_EQ(lhs_mat->dtype(), rhs_mat->dtype())
<< "SpSpMM: lhs_mat and rhs_mat should have the same dtype";
}
// Select `value` entries of `mat` at the nonzero positions of `sub_mat`.
// Positions of `sub_mat` that are not stored in `mat` yield 0.
torch::Tensor _CSRMask(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value,
const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
auto csr = CSRToOldDGLCSR(mat->CSRPtr());
auto val = TorchTensorToDGLArray(value);
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
runtime::NDArray ret;
ATEN_FLOAT_TYPE_SWITCH(val->dtype, DType, "Value Type", {
ret = aten::CSRGetData<DType>(csr, row, col, val, 0.);
});
return DGLArrayToTorchTensor(ret);
}
variable_list SpSpMMAutoGrad::forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
torch::Tensor rhs_val) {
auto ret_mat =
SpSpMMNoAutoGrad(lhs_mat, lhs_val, rhs_mat, rhs_val, false, false);
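// Save both operands and the product so that backward() can form
// dA = dC @ B^T and dB = A^T @ dC restricted to the input sparsity patterns.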
ctx->saved_data["lhs_mat"] = lhs_mat;
ctx->saved_data["rhs_mat"] = rhs_mat;
ctx->saved_data["ret_mat"] = ret_mat;
ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
ctx->save_for_backward({lhs_val, rhs_val});
auto csr = ret_mat->CSRPtr();
auto val = ret_mat->value();
CHECK(!csr->value_indices.has_value());
return {csr->indptr, csr->indices, val};
}
tensor_list SpSpMMAutoGrad::backward(
AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto lhs_val = saved[0];
auto rhs_val = saved[1];
auto output_grad = grad_outputs[2];
auto lhs_mat = ctx->saved_data["lhs_mat"].toCustomClass<SparseMatrix>();
auto rhs_mat = ctx->saved_data["rhs_mat"].toCustomClass<SparseMatrix>();
auto ret_mat = ctx->saved_data["ret_mat"].toCustomClass<SparseMatrix>();
torch::Tensor lhs_val_grad, rhs_val_grad;
if (ctx->saved_data["lhs_require_grad"].toBool()) {
// A @ B = C -> dA = dC @ (B^T)
auto lhs_mat_grad =
SpSpMMNoAutoGrad(ret_mat, output_grad, rhs_mat, rhs_val, false, true);
lhs_val_grad = _CSRMask(lhs_mat_grad, lhs_mat_grad->value(), lhs_mat);
}
if (ctx->saved_data["rhs_require_grad"].toBool()) {
// A @ B = C -> dB = (A^T) @ dC
auto rhs_mat_grad =
SpSpMMNoAutoGrad(lhs_mat, lhs_val, ret_mat, output_grad, true, false);
rhs_val_grad = _CSRMask(rhs_mat_grad, rhs_mat_grad->value(), rhs_mat);
}
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}
c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
_SpSpMMSanityCheck(lhs_mat, rhs_mat);
auto results = SpSpMMAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
auto indptr = results[0];
auto indices = results[1];
auto value = results[2];
return CreateFromCSR(indptr, indices, value, ret_shape);
}
} // namespace sparse
} // namespace dgl
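The backward rule implemented above (the dense gradient formula, masked to the input's sparsity pattern via `_CSRMask`) can be sanity-checked against dense autograd. A minimal PyTorch sketch, illustrative only and not part of this commit:

import torch

mask_A = (torch.rand(4, 5) < 0.3).float()           # 0/1 sparsity pattern of A
A = (torch.randn(4, 5) * mask_A).requires_grad_(True)
B = torch.randn(5, 3, requires_grad=True)
dC = torch.randn(4, 3)                               # upstream gradient for C = A @ B
(A @ B).backward(dC)
# Dense autograd gives dA = dC @ B^T and dB = A^T @ dC; SpSpMM's backward keeps
# only the entries of dA (resp. dB) that lie on A's (resp. B's) sparsity pattern.
assert torch.allclose(A.grad, dC @ B.detach().T)
assert torch.allclose(B.grad, A.detach().T @ dC)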
...@@ -4,11 +4,11 @@ from typing import Union
import torch
from .diag_matrix import diag, DiagMatrix
from .sparse_matrix import SparseMatrix
__all__ = ["spmm", "spspmm", "mm"]
def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
...@@ -53,51 +53,138 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)
def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix.
Parameters
----------
A1 : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Returns
-------
DiagMatrix
The result of multiplication.
"""
M, N = A1.shape
N, P = A2.shape
common_diag_len = min(M, N, P)
new_diag_len = min(M, P)
diag_val = torch.zeros(new_diag_len)
diag_val[:common_diag_len] = (
A1.val[:common_diag_len] * A2.val[:common_diag_len]
)
return diag(diag_val.to(A1.device), (M, P))
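# Illustrative sketch (not part of this commit): multiplying two rectangular
# diagonal matrices keeps only the overlapping part of the diagonals, e.g.
#   _diag_diag_mm(diag(torch.tensor([1., 2., 3.]), (3, 4)),
#                 diag(torch.ones(2), (4, 2))).val  ->  tensor([1., 2.])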
def spspmm(
A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
"""Multiply a sparse matrix by a sparse matrix. The non-zero values of the
two sparse matrices must be 1D.
Parameters
----------
A1 : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
A2 : SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz)
Returns
-------
SparseMatrix or DiagMatrix
The result of multiplication. It is a DiagMatrix object if both matrices
are DiagMatrix objects. It is a SparseMatrix object otherwise.
Examples
--------
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> A1 = create_from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> A2 = create_from_coo(row2, col2, val2)
>>> result = dgl.sparse.spspmm(A1, A2)
>>> print(result)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(2, 3), nnz=5)
"""
assert isinstance(
A1, (SparseMatrix, DiagMatrix)
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}"
assert isinstance(
A2, (SparseMatrix, DiagMatrix)
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}"
if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
return _diag_diag_mm(A1, A2)
if isinstance(A1, DiagMatrix):
A1 = A1.as_sparse()
if isinstance(A2, DiagMatrix):
A2 = A2.as_sparse()
return SparseMatrix(
torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
)
def mm(
A1: Union[SparseMatrix, DiagMatrix],
A2: Union[torch.Tensor, SparseMatrix, DiagMatrix],
) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
"""Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
If an input is a SparseMatrix or DiagMatrix, its non-zero values should
be 1-D.
Parameters
----------
A1 : SparseMatrix or DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
it should have values of shape (nnz2).
Returns
-------
torch.Tensor or DiagMatrix or SparseMatrix
The result of multiplication of shape (N, P)
* It is a dense torch tensor if :attr:`A2` is so.
* It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
* It is a SparseMatrix object otherwise.
Examples
--------
>>> val = torch.randn(3)
>>> A1 = diag(val)
>>> A2 = torch.randn(3, 2)
>>> result = dgl.sparse.mm(A1, A2)
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([3, 2])
"""
assert isinstance(
A1, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix, or DiagMatrix object, got {type(A1)}."
assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), (
f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix "
f"object, got {type(A2)}."
)
if isinstance(A2, torch.Tensor):
return spmm(A1, A2)
if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
return _diag_diag_mm(A1, A2)
return spspmm(A1, A2)
SparseMatrix.__matmul__ = mm
DiagMatrix.__matmul__ = mm
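# The two assignments above let `A1 @ A2` dispatch to mm() for both
# SparseMatrix and DiagMatrix operands.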
...@@ -11,8 +11,8 @@ def sddmm(
) -> SparseMatrix:
r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM).
``sddmm`` multiplies two dense matrices :attr:`mat1` and :attr:`mat2`
at the nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A`
are not considered during the computation.
Mathematically ``sddmm`` is formulated as:
...@@ -20,19 +20,23 @@
.. math::
out = (mat1 @ mat2) * A
In particular, :attr:`mat1` and :attr:`mat2` can be 1-D, in which case
``mat1 @ mat2`` becomes the outer product of the two vectors (which results
in a matrix).
Parameters
----------
A : SparseMatrix
Sparse matrix of shape ``(M, N)``.
mat1 : Tensor
Dense matrix of shape ``(M, K)`` or ``(M,)``
mat2 : Tensor
Dense matrix of shape ``(K, N)`` or ``(N,)``
Returns
-------
SparseMatrix
Sparse matrix of shape ``(M, N)``.
Examples
--------
......
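A short sketch of the newly documented 1-D case (illustrative, not part of the diff above; it assumes ``sddmm`` and ``create_from_coo`` are importable from ``dgl.mock_sparse2`` as in this commit's tests):

import torch
from dgl.mock_sparse2 import create_from_coo, sddmm

row, col = torch.tensor([0, 1, 1]), torch.tensor([1, 0, 2])
A = create_from_coo(row, col, torch.ones(3), (2, 3))   # nonzero values are all ones
mat1 = torch.randn(2)   # treated as a (2, 1) column vector
mat2 = torch.randn(3)   # treated as a (1, 3) row vector
out = sddmm(A, mat1, mat2)
# Each nonzero (r, c) of A samples the outer product: mat1[r] * mat2[c].
assert torch.allclose(out.val, mat1[row] * mat2[col])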
...@@ -51,3 +51,41 @@ def test_spmm(create_func, shape, nnz, out_dim):
sparse_matrix_to_dense(val_like(A, A.val.grad)),
atol=1e-05,
)
@pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
@pytest.mark.parametrize("shape_k", [3, 4])
@pytest.mark.parametrize("nnz1", [1, 10])
@pytest.mark.parametrize("nnz2", [1, 10])
def test_sparse_sparse_mm(
create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2
):
dev = F.ctx()
shape1 = shape_n_m
shape2 = (shape_n_m[1], shape_k)
A1 = create_func1(shape1, nnz1, dev)
A2 = create_func2(shape2, nnz2, dev)
A3 = A1 @ A2
grad = torch.randn_like(A3.val)
A3.val.backward(grad)
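# Reference: repeat the product and the backward pass with torch.sparse
# tensors and compare both the values and the gradients below.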
torch_A1 = sparse_matrix_to_torch_sparse(A1)
torch_A2 = sparse_matrix_to_torch_sparse(A2)
torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
torch_A3.backward(torch_A3_grad)
with torch.no_grad():
assert torch.allclose(A3.dense(), torch_A3.to_dense(), atol=1e-05)
assert torch.allclose(
val_like(A1, A1.val.grad).dense(),
torch_A1.grad.to_dense(),
atol=1e-05,
)
assert torch.allclose(
val_like(A2, A2.val.grad).dense(),
torch_A2.grad.to_dense(),
atol=1e-05,
)
...@@ -13,7 +13,7 @@ if not sys.platform.startswith("linux"):
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
@pytest.mark.parametrize("nnz", [2, 10])
@pytest.mark.parametrize("hidden", [1, 5])
def test_sddmm(create_func, shape, nnz, hidden):
......
import numpy as np
import torch
from dgl.mock_sparse2 import (
create_from_coo,
create_from_csc,
...@@ -6,6 +8,9 @@ from dgl.mock_sparse2 import (
SparseMatrix,
)
np.random.seed(42)
torch.random.manual_seed(42)
def clone_detach_and_grad(t):
t = t.clone().detach()
...@@ -14,33 +19,80 @@ def clone_detach_and_grad(t):
def rand_coo(shape, nnz, dev):
# Create a sparse matrix without duplicate entries.
nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
nnzid = torch.tensor(nnzid, device=dev).long()
row = torch.div(nnzid, shape[1], rounding_mode="floor")
col = nnzid % shape[1]
val = torch.randn(nnz, device=dev, requires_grad=True)
return create_from_coo(row, col, val, shape)
def rand_csr(shape, nnz, dev):
# Create a sparse matrix without duplicate entries.
nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
nnzid = torch.tensor(nnzid, device=dev).long()
row = torch.div(nnzid, shape[1], rounding_mode="floor")
col = nnzid % shape[1]
val = torch.randn(nnz, device=dev, requires_grad=True)
indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
for r in row.tolist():
indptr[r + 1] += 1
indptr = torch.cumsum(indptr, 0)
row_sorted, row_sorted_idx = torch.sort(row)
indices = col[row_sorted_idx]
return create_from_csr(indptr, indices, val, shape=shape)
def rand_csc(shape, nnz, dev):
# Create a sparse matrix without duplicate entries.
nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
nnzid = torch.tensor(nnzid, device=dev).long()
row = torch.div(nnzid, shape[1], rounding_mode="floor")
col = nnzid % shape[1]
val = torch.randn(nnz, device=dev, requires_grad=True)
indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
for c in col.tolist():
indptr[c + 1] += 1
indptr = torch.cumsum(indptr, 0)
col_sorted, col_sorted_idx = torch.sort(col)
indices = row[col_sorted_idx]
return create_from_csc(indptr, indices, val, shape=shape)
def rand_coo_uncoalesced(shape, nnz, dev):
# Create a sparse matrix with possible duplicate entries.
row = torch.randint(shape[0], (nnz,), device=dev)
col = torch.randint(shape[1], (nnz,), device=dev)
val = torch.randn(nnz, device=dev, requires_grad=True)
return create_from_coo(row, col, val, shape)
def rand_csr_uncoalesced(shape, nnz, dev):
# Create a sparse matrix with possible duplicate entries.
row = torch.randint(shape[0], (nnz,), device=dev)
col = torch.randint(shape[1], (nnz,), device=dev)
val = torch.randn(nnz, device=dev, requires_grad=True)
indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
for r in row.tolist():
indptr[r + 1] += 1
indptr = torch.cumsum(indptr, 0)
row_sorted, row_sorted_idx = torch.sort(row)
indices = col[row_sorted_idx]
return create_from_csr(indptr, indices, val, shape=shape)
def rand_csc_uncoalesced(shape, nnz, dev):
# Create a sparse matrix with possible duplicate entries.
row = torch.randint(shape[0], (nnz,), device=dev)
col = torch.randint(shape[1], (nnz,), device=dev)
val = torch.randn(nnz, device=dev, requires_grad=True)
indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
for c in col.tolist():
indptr[c + 1] += 1
indptr = torch.cumsum(indptr, 0)
col_sorted, col_sorted_idx = torch.sort(col)
indices = row[col_sorted_idx]
return create_from_csc(indptr, indices, val, shape=shape)
...@@ -50,11 +102,13 @@ def sparse_matrix_to_dense(A: SparseMatrix):
return dense
def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):
row, col = A.coo()
edge_index = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)
shape = A.shape
if val is None:
val = A.val
val = val.clone().detach()
if len(A.val.shape) > 1:
shape += (A.val.shape[-1],)
ret = torch.sparse_coo_tensor(edge_index, val, shape).coalesce()
......