"src/graph/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "cbd55eb12f8fba2ee848de64ce12cfbe03138e2b"
Unverified commit 14429df6, authored by czkkkkkk and committed by GitHub

[Sparse] Support BSpMM and BSDDMM (#5079)

* [Sparse] Support BSpMM and BSDDMM

* Update SpMM and SDDMM error messages

* Use TORCH_CHECK

* Update error string
parent bda6a816
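The batched semantics this commit introduces can be summarized with plain dense PyTorch ops. The sketch below is illustrative only (names and shapes are made up and nothing in it comes from the patch); it treats a batched sparse matrix of shape (n, m, b) as one length-b value vector per stored nonzero, with the batch dimension always last:

import torch

n, m, k, B = 4, 5, 3, 2
row = torch.tensor([0, 1, 1, 2, 3])
col = torch.tensor([1, 0, 2, 4, 3])
val = torch.randn(len(row), B)     # one length-B value vector per nonzero
A_dense = torch.zeros(n, m, B)
A_dense[row, col] = val            # dense view of the batched sparse matrix

# BSpMM: (n, m, b) x (m, k, b) -> (n, k, b), computed independently per batch.
X = torch.randn(m, k, B)
bspmm_ref = torch.einsum("nmb,mkb->nkb", A_dense, X)

# BSDDMM: per-batch (mat1 @ mat2), sampled at A's nonzeros and scaled by A's
# values, giving one length-B vector per nonzero.
mat1 = torch.randn(n, k, B)
mat2 = torch.randn(k, m, B)
bsddmm_val_ref = torch.einsum("nkb,kmb->nmb", mat1, mat2)[row, col] * val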
@@ -14,10 +14,15 @@ namespace sparse {
 /**
  * @brief Perform a sampled matrix multiplication of a sparse matrix and two
- * dense matrices. It calculates `(mat1 @ mat2) * sparse_mat`. If the sparse
- * matrix has shape (n, m), `mat1` and `mat2` must have shapes of `(n, k)` and
- * `(k, m)` or `(n,)` and `(m,)` respectively. And the returned tensor has shape
- * `(sparse_matrix->nnz(),)`.
+ * dense matrices. It calculates `sparse_mat * (mat1 @ mat2)`. The SDDMM can be
+ * batched, where the batch dimension is the last dimension for all input
+ * matrices.
+ *
+ * There are four cases for the input and output matrix shapes:
+ * (1) (n, m), (n, k), (k, m), and (n, m);
+ * (2) (n, m), (n,), (m,), and (n, m);
+ * (3) (n, m, b), (n, k, b), (k, m, b), and (n, m, b);
+ * (4) (n, m), (n, k, b), (k, m, b), and (n, m, b).
  *
  * This function supports autograd for `mat1` and `mat2` but does not support
  * high order gradient.
...
@@ -14,9 +14,13 @@ namespace sparse {
 /**
  * @brief Perform a matrix multiplication of the sparse matrix and dense
- * matrix. The sparse matrix must have 1-dimensional values. If the sparse
- * matrix has shape (n, m), the dense matrix must have shape (m, k) or (m,), and
- * the returned dense matrix has shape (n, k) or (n,).
+ * matrix. The SpMM can be batched, where the batch dimension is the last
+ * dimension for both sparse and dense matrices.
+ *
+ * There are three cases for sparse, dense, and output matrix shapes:
+ * (1) (n, m), (m, k), and (n, k);
+ * (2) (n, m), (m,), and (n,);
+ * (3) (n, m, b), (m, k, b), and (n, k, b).
  *
  * This function supports autograd for both the sparse and dense matrix but does
  * not support higher order gradient.
...
@@ -24,7 +24,11 @@ torch::Tensor SpMMNoAutoGrad(
   const std::string reduce = "sum";
   const int64_t out_row =
       transpose_sparse ? sparse_mat->shape()[1] : sparse_mat->shape()[0];
-  const std::vector<int64_t> shape = {out_row, dense_mat.size(1)};
+  std::vector<int64_t> shape = {out_row, dense_mat.size(1)};
+  // Batched SpMM
+  if (sparse_val.dim() >= 2) {
+    shape = {out_row, dense_mat.size(1), sparse_val.size(1)};
+  }
   auto ret = torch::zeros(shape, dense_mat.options());
   auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
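When the sparse values carry a batch dimension, the SpMM output simply gains that dimension as its last axis, and the computation is equivalent to one ordinary SpMM per batch slice. A minimal dense stand-in for that semantics (illustrative only; names and shapes below are made up, not taken from the patch):

import torch

n, m, k, B = 4, 5, 3, 2
A_dense = torch.randn(n, m, B)   # stand-in for a sparse matrix with batched values
X = torch.randn(m, k, B)

# Mirrors shape = {out_row, dense_mat.size(1), sparse_val.size(1)} above.
out = torch.stack([A_dense[..., b] @ X[..., b] for b in range(B)], dim=-1)
assert out.shape == (n, k, B)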
@@ -74,7 +78,15 @@ torch::Tensor SDDMMNoAutoGrad(
     const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
     torch::Tensor mat2_tr) {
   const int64_t out_row = sparse_mat->nnz();
-  const std::vector<int64_t> shape({out_row});
+  std::vector<int64_t> shape({out_row});
+  // Batched SDDMM
+  if (mat1.dim() >= 3) {
+    shape.push_back(mat1.size(2));
+    // (N, K, B) -> (N, B, K)
+    mat1 = mat1.transpose(1, 2).contiguous();
+    // (M, K, B) -> (M, B, K)
+    mat2_tr = mat2_tr.transpose(1, 2).contiguous();
+  }
   auto ret = torch::zeros(shape, mat1.options());
   const std::string op = "dot";
   auto dgl_mat1 = TorchTensorToDGLArray(mat1);
...
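The two transposes exist because the backend "dot" operator reduces over the last dimension for each nonzero, so moving K to the last axis turns the reduction into a per-batch dot product. A small PyTorch check of that equivalence (shapes and names below are made up for illustration, not from the patch):

import torch

N, M, K, B = 4, 5, 3, 2
mat1 = torch.randn(N, K, B)
mat2_tr = torch.randn(M, K, B)     # second operand, already transposed
row = torch.tensor([0, 2])         # nonzero coordinates of the sparse matrix
col = torch.tensor([1, 3])

a = mat1.transpose(1, 2)           # (N, K, B) -> (N, B, K): K is now last
b = mat2_tr.transpose(1, 2)        # (M, K, B) -> (M, B, K)
out = (a[row] * b[col]).sum(-1)    # (nnz, B): one length-B dot product per nonzero

ref = torch.einsum("nkb,mkb->nmb", mat1, mat2_tr)[row, col]
assert torch.allclose(out, ref)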
@@ -7,6 +7,8 @@
 #include <sparse/spmm.h>
 #include <torch/script.h>
 
+#include <sstream>
+
 #include "./matmul.h"
 #include "./utils.h"
@@ -27,27 +29,40 @@ class SDDMMAutoGrad : public Function<SDDMMAutoGrad> {
 void _SDDMMSanityCheck(
     const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
     torch::Tensor mat2) {
-  const int64_t mat1_dim = mat1.dim();
-  const int64_t mat2_dim = mat2.dim();
-  CHECK_EQ(mat1_dim, mat2_dim)
-      << "SDDMM: the two dense matrices should have the same dimensions.";
-  CHECK_LE(mat1_dim, 2)
-      << "SDDMM: the first dense matrix should have at most two dimensions.";
-  CHECK_EQ(sparse_mat->shape()[0], mat1.size(0))
-      << "SDDMM: the first dense matrix should have the same first dimension "
-         "as the sparse matrix";
-  CHECK_EQ(sparse_mat->shape()[1], mat2.size(mat2_dim - 1))
-      << "SDDMM: the second dense matrix should have the same last dimension "
-         "as the sparse matrix";
-  if (mat1_dim == 2) {
-    CHECK_EQ(mat1.size(1), mat2.size(0))
-        << "SDDMM: the second dimension of the first dense matrix should be "
-           "equal to the first dimension of the second dense matrix.";
-  }
-  CHECK_EQ(mat1.dtype(), mat2.dtype())
-      << "SDDMM: the two dense matrices should have the same dtype.";
-  CHECK_EQ(mat1.device(), mat2.device())
-      << "SDDMM: the two dense matrices should on the same device.";
+  bool shape_check = true;
+  shape_check &= mat1.dim() == mat2.dim();
+  shape_check &= mat1.dim() <= 3;
+  shape_check &= sparse_mat->shape()[0] == mat1.size(0);
+  if (mat1.dim() == 3) {
+    shape_check &= sparse_mat->shape()[1] == mat2.size(1);
+    shape_check &= mat1.size(2) == mat2.size(2);
+    if (sparse_mat->value().dim() > 1) {
+      shape_check &= sparse_mat->value().size(1) == mat1.size(2);
+    }
+  } else {
+    shape_check &= sparse_mat->shape()[1] == mat2.size(mat2.dim() - 1);
+  }
+  if (mat1.dim() >= 2) {
+    shape_check &= mat1.size(1) == mat2.size(0);
+  }
+  if (!shape_check) {
+    std::stringstream error;
+    error << "SDDMM: Invalid input shapes. sparse_mat: "
+          << c10::IntArrayRef(sparse_mat->shape())
+          << ", sparse_val: " << sparse_mat->value().sizes()
+          << ", mat1: " << mat1.sizes() << ", mat2: " << mat2.sizes()
+          << ". Valid input shapes (sparse_mat, mat1, mat2) are: (1) (n, m), "
+             "(n, k), and (k, m); (2) (n, m), (n,), and (m,); (3) (n, m, b), "
+             "(n, k, b) and (k, m, b); (4) "
+             "(n, m), (n, k, b), and (k, m, b).";
+    TORCH_CHECK(false, error.str());
+  }
+  TORCH_CHECK(
+      mat1.dtype() == mat2.dtype(),
+      "SDDMM: the two dense matrices should have the same dtype.");
+  TORCH_CHECK(
+      mat1.device() == mat2.device(),
+      "SDDMM: the two dense matrices should on the same device.");
 }
 
 torch::Tensor SDDMMAutoGrad::forward(
@@ -101,7 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SDDMM(
   }
   _SDDMMSanityCheck(sparse_mat, mat1, mat2);
   auto val = SDDMMAutoGrad::apply(sparse_mat, mat1, mat2);
-  val = val * sparse_mat->value();
+  auto sparse_val = sparse_mat->value();
+  // Broadcast the sparse value in batched SDDMM.
+  if (sparse_val.dim() < val.dim()) {
+    sparse_val = sparse_val.unsqueeze(-1);
+  }
+  val = val * sparse_val;
   return CreateValLike(sparse_mat, val);
 }
...
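The unsqueeze covers shape case (4): the sparse matrix keeps one scalar value per nonzero while the dense operands are batched, so the SDDMM result has shape (nnz, B) and the (nnz,) values need a trailing axis to broadcast against it. A minimal illustration of the same logic in PyTorch (not part of the patch):

import torch

nnz, B = 3, 2
val = torch.randn(nnz, B)        # batched SDDMM result, one vector per nonzero
sparse_val = torch.randn(nnz)    # one scalar value per nonzero

if sparse_val.dim() < val.dim():
    sparse_val = sparse_val.unsqueeze(-1)   # (nnz,) -> (nnz, 1)
val = val * sparse_val                      # broadcasts over the batch axis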
@@ -9,6 +9,8 @@
 #include <sparse/spmm.h>
 #include <torch/script.h>
 
+#include <sstream>
+
 #include "./matmul.h"
 #include "./utils.h"
@@ -32,23 +34,34 @@ void _SpMMSanityCheck(
   const auto& sparse_mat_shape = sparse_mat->shape();
   auto val_shape = sparse_val.sizes();
   auto dense_shape = dense_mat.sizes();
-  CHECK_EQ(sparse_mat_shape[1], dense_shape[0])
-      << "SpMM: the second dimension of the sparse matrix should be equal to "
-         "the first dimension of the dense matrix.";
-  CHECK_EQ(val_shape.size(), 1)
-      << "SpMM: the values tensor for SpMM can only be 1-dimensional.";
-  CHECK_EQ(val_shape[0], sparse_mat->nnz())
-      << "SpMM: the value shape does not match nnz of the sparse matrix.";
-  CHECK_LE(dense_shape.size(), 2)
-      << "SpMM: the dense matrix can have at most two dimensions.";
-  CHECK_EQ(sparse_val.dtype(), dense_mat.dtype())
-      << "SpMM: the non-zero values does not have the same dtype as the dense "
-         "matrix.";
-  CHECK(
-      sparse_val.device() == sparse_mat->device() &&
-      sparse_val.device() == dense_mat.device())
-      << "SpMM: sparse matrix, non-zero values and the dense matrix should be "
-         "on the same device.";
+  bool shape_check = true;
+  shape_check &= sparse_mat_shape[1] == dense_shape[0];
+  shape_check &= val_shape.size() <= 2;
+  shape_check &= val_shape[0] == sparse_mat->nnz();
+  shape_check &= dense_shape.size() <= 3;
+  if (dense_shape.size() == 3 || val_shape.size() == 2) {
+    shape_check &= dense_shape.size() == val_shape.size() + 1;
+    shape_check &= dense_shape[2] == val_shape[1];
+  }
+  if (!shape_check) {
+    std::stringstream error;
+    error << "SpMM: Invalid input shapes. sparse_mat: "
+          << c10::IntArrayRef(sparse_mat->shape())
+          << ", sparse_val: " << sparse_mat->value().sizes()
+          << ", dense_mat: " << dense_mat.sizes()
+          << ". Valid input shapes (sparse_mat, dense_mat) are: (1) (n, m) and "
+             "(m, k); (2) (n, m) and (m,); (3) (n, m, b) and (m, k, b).";
+    TORCH_CHECK(false, error.str());
+  }
+  TORCH_CHECK(
+      sparse_val.dtype() == dense_mat.dtype(),
+      "SpMM: the non-zero values does not have the same dtype as the dense "
+      "matrix.");
+  TORCH_CHECK(
+      sparse_val.device() == sparse_mat->device() &&
+          sparse_val.device() == dense_mat.device(),
+      "SpMM: sparse matrix, non-zero values and the dense matrix should be "
+      "on the same device.");
 }
 
 torch::Tensor SpMMAutoGrad::forward(
...
@@ -8,7 +8,7 @@ from .diag_matrix import diag, DiagMatrix
 from .sparse_matrix import SparseMatrix
 
-__all__ = ["spmm", "spspmm", "mm"]
+__all__ = ["spmm", "bspmm", "spspmm", "mm"]
 
 
 def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
@@ -53,6 +53,44 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
     return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)
 
 
+def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
+    """Multiply a sparse matrix by a dense matrix by batches.
+
+    Parameters
+    ----------
+    A : SparseMatrix or DiagMatrix
+        Sparse matrix of shape (N, M, B) with values of shape (nnz, B)
+    X : torch.Tensor
+        Dense tensor of shape (M, F, B)
+
+    Returns
+    -------
+    torch.Tensor
+        The multiplication result of shape (N, F, B)
+
+    Examples
+    --------
+    >>> row = torch.tensor([0, 1, 1])
+    >>> col = torch.tensor([1, 0, 2])
+    >>> val = torch.randn(len(row), 2)
+    >>> A = create_from_coo(row, col, val, shape=(3, 3))
+    >>> X = torch.randn(3, 3, 2)
+    >>> result = dgl.sparse.bspmm(A, X)
+    >>> print(type(result))
+    <class 'torch.Tensor'>
+    >>> print(result.shape)
+    torch.Size([3, 3, 2])
+    """
+    assert isinstance(
+        A, (SparseMatrix, DiagMatrix)
+    ), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
+    assert isinstance(
+        X, torch.Tensor
+    ), f"Expect arg2 to be a torch.Tensor, got {type(X)}"
+    return spmm(A, X)
+
+
 def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
     """Internal function for multiplying a diagonal matrix by a diagonal matrix
...
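A usage-level way to read ``bspmm``: the batched result agrees with one ``spmm`` call per batch slice. The sketch below is illustrative and assumes that ``spmm`` and ``create_from_coo`` are importable from ``dgl.mock_sparse2`` alongside ``bspmm``, as the tests in this commit suggest:

import torch
from dgl.mock_sparse2 import bspmm, create_from_coo, spmm

row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])
val = torch.randn(3, 2)                          # B = 2 batched values
A = create_from_coo(row, col, val, shape=(3, 3))
X = torch.randn(3, 4, 2)

out = bspmm(A, X)                                # shape (3, 4, 2)
for b in range(2):
    # Build the b-th slice as an ordinary sparse matrix and compare.
    A_b = create_from_coo(row, col, val[:, b], shape=(3, 3))
    assert torch.allclose(out[..., b], spmm(A_b, X[..., b]), atol=1e-5)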
@@ -3,7 +3,7 @@ import torch
 from .sparse_matrix import SparseMatrix
 
-__all__ = ["sddmm"]
+__all__ = ["sddmm", "bsddmm"]
 
 
 def sddmm(
@@ -47,7 +47,7 @@ def sddmm(
     >>> A = create_from_coo(row, col, val, (3, 4))
     >>> mat1 = torch.randn(3, 5)
     >>> mat2 = torch.randn(5, 4)
-    >>> dgl.mock_sparse.sddmm(A, mat1, mat2)
+    >>> dgl.sparse.sddmm(A, mat1, mat2)
     SparseMatrix(indices=tensor([[1, 1, 2],
                                  [2, 3, 3]]),
                  values=tensor([ 1.3097, -1.0977, 1.6953]),
@@ -56,3 +56,55 @@ def sddmm(
     return SparseMatrix(
         torch.ops.dgl_sparse.sddmm(A.c_sparse_matrix, mat1, mat2)
     )
+
+
+def bsddmm(
+    A: SparseMatrix, mat1: torch.Tensor, mat2: torch.Tensor
+) -> SparseMatrix:
+    r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.
+
+    ``bsddmm`` multiplies two dense matrices :attr:`mat1` and :attr:`mat2`
+    at the nonzero locations of sparse matrix :attr:`A`, then scales the
+    result by the values of :attr:`A`.
+
+    Mathematically ``bsddmm`` is formulated as:
+
+    .. math::
+        out = (mat1 @ mat2) * A
+
+    The batch dimension is the last dimension for the input matrices. In
+    particular, if the sparse matrix has scalar non-zero values, they will be
+    broadcast for bsddmm.
+
+    Parameters
+    ----------
+    A : SparseMatrix
+        Sparse matrix of shape ``(M, N)`` or ``(M, N, B)``.
+    mat1 : Tensor
+        Dense matrix of shape ``(M, K, B)``
+    mat2 : Tensor
+        Dense matrix of shape ``(K, N, B)``
+
+    Returns
+    -------
+    SparseMatrix
+        Sparse matrix of shape ``(M, N, B)``.
+
+    Examples
+    --------
+    >>> row = torch.tensor([1, 1, 2])
+    >>> col = torch.tensor([2, 3, 3])
+    >>> val = torch.arange(1, 4).float()
+    >>> A = create_from_coo(row, col, val, (3, 4))
+    >>> mat1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()
+    >>> mat2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()
+    >>> dgl.sparse.bsddmm(A, mat1, mat2)
+    SparseMatrix(indices=tensor([[1, 1, 2],
+                                 [2, 3, 3]]),
+                 values=tensor([[1560., 1735.],
+                                [3400., 3770.],
+                                [8400., 9105.]]),
+                 shape=(3, 4), nnz=3)
+    """
+    return sddmm(A, mat1, mat2)
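The numbers in the docstring example can be reproduced with dense ops, which makes the batched semantics concrete (an illustrative check, not part of the patch):

import torch

row = torch.tensor([1, 1, 2])
col = torch.tensor([2, 3, 3])
val = torch.arange(1, 4).float()                 # one scalar value per nonzero
mat1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()
mat2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()

prod = torch.einsum("nkb,kmb->nmb", mat1, mat2)  # per-batch mat1 @ mat2
vals = prod[row, col] * val.unsqueeze(-1)        # scalar values broadcast over B
print(vals)
# tensor([[1560., 1735.],
#         [3400., 3770.],
#         [8400., 9105.]])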
@@ -4,10 +4,11 @@ import backend as F
 import pytest
 import torch
 
-from dgl.mock_sparse2 import create_from_coo, val_like
+from dgl.mock_sparse2 import bspmm, create_from_coo, val_like
 
 from .utils import (
     clone_detach_and_grad,
+    dense_mask,
     rand_coo,
     rand_csc,
     rand_csr,
@@ -53,6 +54,34 @@ def test_spmm(create_func, shape, nnz, out_dim):
     )
 
 
+@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
+@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
+@pytest.mark.parametrize("nnz", [1, 10])
+def test_bspmm(create_func, shape, nnz):
+    dev = F.ctx()
+    A = create_func(shape, nnz, dev, 2)
+    X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
+    sparse_result = bspmm(A, X)
+    grad = torch.randn_like(sparse_result)
+    sparse_result.backward(grad)
+
+    XX = clone_detach_and_grad(X)
+    torch_A = A.dense().clone().detach().requires_grad_()
+    torch_result = torch_A.permute(2, 0, 1) @ XX.permute(2, 0, 1)
+    torch_result.backward(grad.permute(2, 0, 1))
+
+    assert torch.allclose(
+        sparse_result.permute(2, 0, 1), torch_result, atol=1e-05
+    )
+    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
+    assert torch.allclose(
+        dense_mask(torch_A.grad, A),
+        sparse_matrix_to_dense(val_like(A, A.val.grad)),
+        atol=1e-05,
+    )
+
+
 @pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
 @pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
 @pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
...
@@ -3,7 +3,7 @@ import sys
 import backend as F
 import pytest
 import torch
-from dgl.mock_sparse2 import sddmm
+from dgl.mock_sparse2 import bsddmm, sddmm
 
 from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr
@@ -48,3 +48,36 @@ def test_sddmm(create_func, shape, nnz, hidden):
     assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
     assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
     assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)
+
+
+@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
+@pytest.mark.parametrize("nnz", [2, 10])
+@pytest.mark.parametrize("nz_dim", [2, 10])
+def test_bsddmm(create_func, shape, nnz, nz_dim):
+    dev = F.ctx()
+    hidden = 2
+    A = create_func(shape, nnz, dev, nz_dim)
+    B = torch.rand(shape[0], hidden, nz_dim, requires_grad=True, device=dev)
+    C = torch.rand(hidden, shape[1], nz_dim, requires_grad=True, device=dev)
+
+    A_val_clone = clone_detach_and_grad(A.val)
+    dense_B = clone_detach_and_grad(B)
+    dense_C = clone_detach_and_grad(C)
+
+    sparse_result = bsddmm(A, B, C)
+    grad = torch.rand_like(sparse_result.val)
+    sparse_result.val.backward(grad)
+
+    dense_result = dense_B.permute(2, 0, 1) @ dense_C.permute(2, 0, 1)
+    dense_result = dense_result.permute(1, 2, 0)
+    row, col = A.coo()
+    dense_val = dense_result[row, col] * A_val_clone
+    dense_val.backward(grad)
+
+    assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)
+    assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
+    assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
+    assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)
@@ -18,23 +18,29 @@ def clone_detach_and_grad(t):
     return t
 
 
-def rand_coo(shape, nnz, dev):
+def rand_coo(shape, nnz, dev, nz_dim=None):
     # Create a sparse matrix without duplicate entries.
     nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
     nnzid = torch.tensor(nnzid, device=dev).long()
     row = torch.div(nnzid, shape[1], rounding_mode="floor")
     col = nnzid % shape[1]
-    val = torch.randn(nnz, device=dev, requires_grad=True)
+    if nz_dim is None:
+        val = torch.randn(nnz, device=dev, requires_grad=True)
+    else:
+        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
     return create_from_coo(row, col, val, shape)
 
 
-def rand_csr(shape, nnz, dev):
+def rand_csr(shape, nnz, dev, nz_dim=None):
     # Create a sparse matrix without duplicate entries.
     nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
     nnzid = torch.tensor(nnzid, device=dev).long()
     row = torch.div(nnzid, shape[1], rounding_mode="floor")
     col = nnzid % shape[1]
-    val = torch.randn(nnz, device=dev, requires_grad=True)
+    if nz_dim is None:
+        val = torch.randn(nnz, device=dev, requires_grad=True)
+    else:
+        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
     indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
     for r in row.tolist():
         indptr[r + 1] += 1
@@ -44,13 +50,16 @@ def rand_csr(shape, nnz, dev):
     return create_from_csr(indptr, indices, val, shape=shape)
 
 
-def rand_csc(shape, nnz, dev):
+def rand_csc(shape, nnz, dev, nz_dim=None):
     # Create a sparse matrix without duplicate entries.
     nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
     nnzid = torch.tensor(nnzid, device=dev).long()
     row = torch.div(nnzid, shape[1], rounding_mode="floor")
     col = nnzid % shape[1]
-    val = torch.randn(nnz, device=dev, requires_grad=True)
+    if nz_dim is None:
+        val = torch.randn(nnz, device=dev, requires_grad=True)
+    else:
+        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
     indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
     for c in col.tolist():
         indptr[c + 1] += 1
@@ -114,3 +123,11 @@ def sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):
     ret = torch.sparse_coo_tensor(edge_index, val, shape).coalesce()
     ret.requires_grad_()
     return ret
+
+
+def dense_mask(dense, sparse):
+    ret = torch.zeros_like(dense)
+    row, col = sparse.coo()
+    for r, c in zip(row, col):
+        ret[r, c] = dense[r, c]
+    return ret
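``dense_mask`` zeroes every entry of a dense tensor that does not sit at a stored nonzero of the sparse matrix, so dense autograd gradients can be compared against the gradients of the sparse values in the tests above. Since the test utilities create matrices without duplicate coordinates, an equivalent vectorized form (illustrative, not part of the patch) would be:

import torch

def dense_mask_vectorized(dense, sparse):
    # Keep entries only at the sparse matrix's nonzero positions.
    row, col = sparse.coo()
    ret = torch.zeros_like(dense)
    ret[row, col] = dense[row, col]
    return ret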