Unverified Commit 999c6245 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Support conversion to/from torch sparse tensor. (#5388)

* [Sparse] Support conversion to/from torch sparse tensor.

* Update
parent 11d12f3c
......@@ -79,15 +79,14 @@ class SparseMatrix : public torch::CustomClassHolder {
/**
 * @brief Create a SparseMatrix from tensors in COO format.
 * @param indices COO coordinates with shape (2, nnz); row indices are
 *        indices[0] and column indices are indices[1].
 * @param value Values of the sparse matrix.
 * @param shape Shape of the sparse matrix.
 *
 * @return SparseMatrix
 */
static c10::intrusive_ptr<SparseMatrix> FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape);
/**
......@@ -153,6 +152,8 @@ class SparseMatrix : public torch::CustomClassHolder {
/** @return {row, col} tensors in the COO format. */
std::tuple<torch::Tensor, torch::Tensor> COOTensors();
/** @return Stacked row and col tensors in the COO format. */
torch::Tensor Indices();
/** @return {row, col, value_indices} tensors in the CSR format. */
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
CSRTensors();
......
......@@ -22,19 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) {
ElementwiseOpSanityCheck(A, B);
torch::Tensor sum;
{
// TODO(#5145) This is a workaround to reduce peak memory usage. It is no
// longer needed after we address #5145.
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
sum = torch_A + torch_B;
}
sum = sum.coalesce();
auto indices = sum.indices();
auto row = indices[0];
auto col = indices[1];
return SparseMatrix::FromCOO(row, col, sum.values(), A->shape());
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
auto sum = (torch_A + torch_B).coalesce();
return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
}
} // namespace sparse
......
......@@ -27,6 +27,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("device", &SparseMatrix::device)
.def("shape", &SparseMatrix::shape)
.def("coo", &SparseMatrix::COOTensors)
.def("indices", &SparseMatrix::Indices)
.def("csr", &SparseMatrix::CSRTensors)
.def("csc", &SparseMatrix::CSCTensors)
.def("transpose", &SparseMatrix::Transpose)
......
......@@ -72,10 +72,10 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // The COO holds the stacked (2, nnz) index tensor directly.
  // NOTE(review): the two trailing `false` flags presumably mark
  // row/column-sortedness as unknown — confirm against COO's definition.
  auto coo =
      std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
  return SparseMatrix::FromCOOPointer(coo, value, shape);
}
......@@ -138,10 +138,14 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
// Splits the stacked (2, nnz) COO indices into separate {row, col} tensors.
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
  auto coo = COOPtr();
  // Note: the former `auto val = value();` local was unused and is removed.
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
}
// Returns the stacked (2, nnz) COO coordinate tensor.
torch::Tensor SparseMatrix::Indices() { return COOPtr()->indices; }
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
auto csr = CSRPtr();
......
......@@ -17,10 +17,8 @@ namespace sparse {
// Returns a copy of this matrix with duplicate coordinates merged, by
// delegating to torch's sparse-COO coalesce().
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {
  auto torch_coo = COOToTorchCOO(this->COOPtr(), this->value());
  auto coalesced_coo = torch_coo.coalesce();
  return SparseMatrix::FromCOO(
      coalesced_coo.indices(), coalesced_coo.values(), this->shape());
}
bool SparseMatrix::HasDuplicate() {
......
......@@ -109,12 +109,35 @@ class SparseMatrix:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.coo()
(tensor([1, 2, 1]), tensor([2, 4, 3]))
"""
return self.c_sparse_matrix.coo()
def indices(self) -> torch.Tensor:
    r"""Returns the coordinate list (COO) representation in one tensor with
    shape ``(2, nnz)``.

    See `COO in Wikipedia <https://en.wikipedia.org/wiki/
    Sparse_matrix#Coordinate_list_(COO)>`_.

    Returns
    -------
    torch.Tensor
        Stacked COO tensor with shape ``(2, nnz)``.

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> A = dglsp.spmatrix(indices)
    >>> A.indices()
    tensor([[1, 2, 1],
            [2, 4, 3]])
    """
    # Thin delegation to the underlying C++ sparse matrix object.
    backing = self.c_sparse_matrix
    return backing.indices()
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Returns the compressed sparse row (CSR) representation of the sparse
matrix.
......@@ -140,7 +163,7 @@ class SparseMatrix:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.csr()
(tensor([0, 0, 2, 3]), tensor([2, 3, 4]), tensor([0, 2, 1]))
"""
......@@ -171,7 +194,7 @@ class SparseMatrix:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.csc()
(tensor([0, 0, 0, 1, 2, 3]), tensor([1, 1, 2]), tensor([0, 2, 1]))
"""
......@@ -521,7 +544,18 @@ def spmatrix(
[3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,))
"""
return from_coo(indices[0], indices[1], val, shape)
if shape is None:
shape = (
torch.max(indices[0]).item() + 1,
torch.max(indices[1]).item() + 1,
)
if val is None:
val = torch.ones(indices.shape[1]).to(indices.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(indices, val, shape))
def from_coo(
......@@ -599,16 +633,8 @@ def from_coo(
[3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,))
"""
if shape is None:
shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)
if val is None:
val = torch.ones(row.shape[0]).to(row.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape))
assert row.shape[0] == col.shape[0]
return spmatrix(torch.stack([row, col]), val, shape)
def from_csr(
......@@ -833,6 +859,171 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
    """Creates a sparse matrix from a torch sparse tensor, which can have coo,
    csr, or csc layout.

    Parameters
    ----------
    torch_sparse_tensor : torch.Tensor
        Torch sparse tensor

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> torch_coo = torch.sparse_coo_tensor(indices, val)
    >>> dglsp.from_torch_sparse(torch_coo)
    SparseMatrix(indices=tensor([[1, 1, 2],
                                 [2, 4, 3]]),
                 values=tensor([1., 1., 1.]),
                 shape=(3, 5), nnz=3)
    """
    layout = torch_sparse_tensor.layout
    assert layout in (
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_csc,
    ), (
        f"Cannot convert Pytorch sparse tensor with layout "
        f"{torch_sparse_tensor.layout} to DGL sparse."
    )
    # Dense trailing dims of the torch tensor are value dims of the matrix.
    matrix_shape = torch_sparse_tensor.shape[:2]
    if layout == torch.sparse_coo:
        # ._indices()/._values() expose the raw (possibly uncoalesced)
        # storage without forcing a coalesce.
        return spmatrix(
            torch_sparse_tensor._indices(),
            torch_sparse_tensor._values(),
            matrix_shape,
        )
    if layout == torch.sparse_csr:
        return from_csr(
            torch_sparse_tensor.crow_indices(),
            torch_sparse_tensor.col_indices(),
            torch_sparse_tensor.values(),
            matrix_shape,
        )
    return from_csc(
        torch_sparse_tensor.ccol_indices(),
        torch_sparse_tensor.row_indices(),
        torch_sparse_tensor.values(),
        matrix_shape,
    )
def to_torch_sparse_coo(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse coo tensor from a sparse matrix.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        torch tensor with torch.sparse_coo layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_coo(spmat)
    tensor(indices=tensor([[1, 1, 2],
                           [2, 4, 3]]),
           values=tensor([1., 1., 1.]),
           size=(3, 5), nnz=3, layout=torch.sparse_coo)
    """
    # Vector-shaped values add dense dimensions to the torch tensor shape.
    out_shape = spmat.shape
    if spmat.val.dim() > 1:
        out_shape += spmat.val.shape[1:]
    return torch.sparse_coo_tensor(spmat.indices(), spmat.val, out_shape)
def to_torch_sparse_csr(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csr tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csr tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csr layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csr(spmat)
    tensor(crow_indices=tensor([0, 0, 2, 3]),
           col_indices=tensor([2, 3, 4]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csr)
    """
    # Vector-shaped values add dense dimensions to the torch tensor shape.
    out_shape = spmat.shape
    if spmat.val.dim() > 1:
        out_shape += spmat.val.shape[1:]
    indptr, col, value_indices = spmat.csr()
    # A non-None value_indices permutes values into CSR nonzero order.
    values = spmat.val if value_indices is None else spmat.val[value_indices]
    return torch.sparse_csr_tensor(indptr, col, values, out_shape)
def to_torch_sparse_csc(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csc tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csc tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csc layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csc(spmat)
    tensor(ccol_indices=tensor([0, 0, 0, 1, 2, 3]),
           row_indices=tensor([1, 1, 2]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csc)
    """
    # Vector-shaped values add dense dimensions to the torch tensor shape.
    out_shape = spmat.shape
    if spmat.val.dim() > 1:
        out_shape += spmat.val.shape[1:]
    indptr, row, value_indices = spmat.csc()
    # A non-None value_indices permutes values into CSC nonzero order.
    values = spmat.val if value_indices is None else spmat.val[value_indices]
    return torch.sparse_csc_tensor(indptr, row, values, out_shape)
def _sparse_matrix_str(spmat: SparseMatrix) -> str:
"""Internal function for converting a sparse matrix to string
representation.
......
......@@ -5,7 +5,16 @@ import backend as F
import pytest
import torch
from dgl.sparse import from_coo, from_csc, from_csr, val_like
from dgl.sparse import (
from_coo,
from_csc,
from_csr,
from_torch_sparse,
to_torch_sparse_coo,
to_torch_sparse_csc,
to_torch_sparse_csr,
val_like,
)
@pytest.mark.parametrize("dense_dim", [None, 4])
......@@ -502,3 +511,98 @@ def test_sparse_matrix_transpose(dense_dim, row, col, extra_shape):
assert torch.allclose(mat_val, val)
assert torch.allclose(mat_row, col)
assert torch.allclose(mat_col, row)
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("nz_dim", [None, 2])
@pytest.mark.parametrize("shape", [(5, 5), (6, 7)])
def test_torch_sparse_coo_conversion(row, col, nz_dim, shape):
    """Round-trips a torch COO tensor through SparseMatrix and back."""
    device = F.ctx()
    row = torch.tensor(row).to(device)
    col = torch.tensor(col).to(device)
    stacked = torch.stack([row, col])
    sparse_shape = shape
    value_shape = (row.shape[0],)
    if nz_dim is not None:
        sparse_shape += (nz_dim,)
        value_shape += (nz_dim,)
    value = torch.randn(value_shape).to(device)
    torch_coo = torch.sparse_coo_tensor(stacked, value, sparse_shape)
    spmat = from_torch_sparse(torch_coo)

    def _check(mat, coo):
        assert coo.layout == torch.sparse_coo
        # data_ptr() equality proves the conversion shares storage
        # rather than copying indices/values.
        assert mat.indices().data_ptr() == coo._indices().data_ptr()
        assert mat.val.data_ptr() == coo._values().data_ptr()
        assert mat.shape == coo.shape[:2]

    _check(spmat, torch_coo)
    _check(spmat, to_torch_sparse_coo(spmat))
@pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("shape", [(3, 5), (3, 7)])
def test_torch_sparse_csr_conversion(indptr, indices, shape):
    """Round-trips a torch CSR tensor through SparseMatrix and back."""
    device = F.ctx()
    indptr = torch.tensor(indptr).to(device)
    indices = torch.tensor(indices).to(device)
    value = torch.randn((indices.shape[0],)).to(device)
    torch_csr = torch.sparse_csr_tensor(indptr, indices, value, shape)
    spmat = from_torch_sparse(torch_csr)

    def _check(mat, csr):
        row_ptr, col, value_indices = mat.csr()
        assert csr.layout == torch.sparse_csr
        assert value_indices is None
        # data_ptr() equality proves the conversion shares storage
        # rather than copying indices/values.
        assert row_ptr.data_ptr() == csr.crow_indices().data_ptr()
        assert col.data_ptr() == csr.col_indices().data_ptr()
        assert mat.val.data_ptr() == csr.values().data_ptr()
        assert mat.shape == csr.shape[:2]

    _check(spmat, torch_csr)
    _check(spmat, to_torch_sparse_csr(spmat))
@pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("shape", [(8, 3), (5, 3)])
def test_torch_sparse_csc_conversion(indptr, indices, shape):
    """Round-trips a torch CSC tensor through SparseMatrix and back."""
    device = F.ctx()
    indptr = torch.tensor(indptr).to(device)
    indices = torch.tensor(indices).to(device)
    value = torch.randn((indices.shape[0],)).to(device)
    torch_csc = torch.sparse_csc_tensor(indptr, indices, value, shape)
    spmat = from_torch_sparse(torch_csc)

    def _check(mat, csc):
        col_ptr, row, value_indices = mat.csc()
        assert csc.layout == torch.sparse_csc
        assert value_indices is None
        # data_ptr() equality proves the conversion shares storage
        # rather than copying indices/values.
        assert col_ptr.data_ptr() == csc.ccol_indices().data_ptr()
        assert row.data_ptr() == csc.row_indices().data_ptr()
        assert mat.val.data_ptr() == csc.values().data_ptr()
        assert mat.shape == csc.shape[:2]

    _check(spmat, torch_csc)
    _check(spmat, to_torch_sparse_csc(spmat))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment