"src/vscode:/vscode.git/clone" did not exist on "64f49703af4644ed629b09e8ada8a6af1ded04e5"
Unverified Commit 999c6245 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Support conversion to/from torch sparse tensor. (#5388)

* [Sparse] Support conversion to/from torch sparse tensor.

* Update
parent 11d12f3c
...@@ -79,15 +79,14 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -79,15 +79,14 @@ class SparseMatrix : public torch::CustomClassHolder {
/** /**
* @brief Create a SparseMatrix from tensors in COO format. * @brief Create a SparseMatrix from tensors in COO format.
* @param row Row indices of the COO. * @param indices COO coordinates with shape (2, nnz).
* @param col Column indices of the COO.
* @param value Values of the sparse matrix. * @param value Values of the sparse matrix.
* @param shape Shape of the sparse matrix. * @param shape Shape of the sparse matrix.
* *
* @return SparseMatrix * @return SparseMatrix
*/ */
static c10::intrusive_ptr<SparseMatrix> FromCOO( static c10::intrusive_ptr<SparseMatrix> FromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/** /**
...@@ -153,6 +152,8 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -153,6 +152,8 @@ class SparseMatrix : public torch::CustomClassHolder {
/** @return {row, col} tensors in the COO format. */ /** @return {row, col} tensors in the COO format. */
std::tuple<torch::Tensor, torch::Tensor> COOTensors(); std::tuple<torch::Tensor, torch::Tensor> COOTensors();
/** @return Stacked row and col tensors in the COO format. */
torch::Tensor Indices();
/** @return {row, col, value_indices} tensors in the CSR format. */ /** @return {row, col, value_indices} tensors in the CSR format. */
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
CSRTensors(); CSRTensors();
......
...@@ -22,19 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd( ...@@ -22,19 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A, const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) { const c10::intrusive_ptr<SparseMatrix>& B) {
ElementwiseOpSanityCheck(A, B); ElementwiseOpSanityCheck(A, B);
torch::Tensor sum;
{
// TODO(#5145) This is a workaround to reduce peak memory usage. It is no
// longer needed after we address #5145.
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value()); auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value()); auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
sum = torch_A + torch_B; auto sum = (torch_A + torch_B).coalesce();
} return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
sum = sum.coalesce();
auto indices = sum.indices();
auto row = indices[0];
auto col = indices[1];
return SparseMatrix::FromCOO(row, col, sum.values(), A->shape());
} }
} // namespace sparse } // namespace sparse
......
...@@ -27,6 +27,7 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -27,6 +27,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("device", &SparseMatrix::device) .def("device", &SparseMatrix::device)
.def("shape", &SparseMatrix::shape) .def("shape", &SparseMatrix::shape)
.def("coo", &SparseMatrix::COOTensors) .def("coo", &SparseMatrix::COOTensors)
.def("indices", &SparseMatrix::Indices)
.def("csr", &SparseMatrix::CSRTensors) .def("csr", &SparseMatrix::CSRTensors)
.def("csc", &SparseMatrix::CSCTensors) .def("csc", &SparseMatrix::CSCTensors)
.def("transpose", &SparseMatrix::Transpose) .def("transpose", &SparseMatrix::Transpose)
......
...@@ -72,10 +72,10 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer( ...@@ -72,10 +72,10 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
} }
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO( c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape) { const std::vector<int64_t>& shape) {
auto coo = std::make_shared<COO>( auto coo =
COO{shape[0], shape[1], torch::stack({row, col}), false, false}); std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
return SparseMatrix::FromCOOPointer(coo, value, shape); return SparseMatrix::FromCOOPointer(coo, value, shape);
} }
...@@ -138,10 +138,14 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() { ...@@ -138,10 +138,14 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() { std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
auto coo = COOPtr(); auto coo = COOPtr();
auto val = value();
return std::make_tuple(coo->indices.index({0}), coo->indices.index({1})); return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
} }
/** @brief Returns the stacked row/col coordinate tensor of the COO form. */
torch::Tensor SparseMatrix::Indices() {
  // The internal COO representation already stores the row and column ids
  // stacked into a single (2, nnz) tensor, so hand it back without copying.
  return COOPtr()->indices;
}
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() { SparseMatrix::CSRTensors() {
auto csr = CSRPtr(); auto csr = CSRPtr();
......
...@@ -17,10 +17,8 @@ namespace sparse { ...@@ -17,10 +17,8 @@ namespace sparse {
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() { c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {
auto torch_coo = COOToTorchCOO(this->COOPtr(), this->value()); auto torch_coo = COOToTorchCOO(this->COOPtr(), this->value());
auto coalesced_coo = torch_coo.coalesce(); auto coalesced_coo = torch_coo.coalesce();
torch::Tensor indices = coalesced_coo.indices(); return SparseMatrix::FromCOO(
torch::Tensor row = indices[0]; coalesced_coo.indices(), coalesced_coo.values(), this->shape());
torch::Tensor col = indices[1];
return SparseMatrix::FromCOO(row, col, coalesced_coo.values(), this->shape());
} }
bool SparseMatrix::HasDuplicate() { bool SparseMatrix::HasDuplicate() {
......
...@@ -109,12 +109,35 @@ class SparseMatrix: ...@@ -109,12 +109,35 @@ class SparseMatrix:
-------- --------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]]) >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src) >>> A = dglsp.spmatrix(indices)
>>> A.coo() >>> A.coo()
(tensor([1, 2, 1]), tensor([2, 4, 3])) (tensor([1, 2, 1]), tensor([2, 4, 3]))
""" """
return self.c_sparse_matrix.coo() return self.c_sparse_matrix.coo()
def indices(self) -> torch.Tensor:
    r"""Returns the coordinate list (COO) representation as one tensor of
    shape ``(2, nnz)``, where row 0 holds the row ids and row 1 the column
    ids of the non-zero entries.

    See `COO in Wikipedia <https://en.wikipedia.org/wiki/
    Sparse_matrix#Coordinate_list_(COO)>`_.

    Returns
    -------
    torch.Tensor
        Stacked COO tensor with shape ``(2, nnz)``.

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> A = dglsp.spmatrix(indices)
    >>> A.indices()
    tensor([[1, 2, 1],
            [2, 4, 3]])
    """
    # Delegate to the C++ SparseMatrix, which returns its stored indices
    # tensor directly (no copy).
    return self.c_sparse_matrix.indices()
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Returns the compressed sparse row (CSR) representation of the sparse r"""Returns the compressed sparse row (CSR) representation of the sparse
matrix. matrix.
...@@ -140,7 +163,7 @@ class SparseMatrix: ...@@ -140,7 +163,7 @@ class SparseMatrix:
-------- --------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]]) >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src) >>> A = dglsp.spmatrix(indices)
>>> A.csr() >>> A.csr()
(tensor([0, 0, 2, 3]), tensor([2, 3, 4]), tensor([0, 2, 1])) (tensor([0, 0, 2, 3]), tensor([2, 3, 4]), tensor([0, 2, 1]))
""" """
...@@ -171,7 +194,7 @@ class SparseMatrix: ...@@ -171,7 +194,7 @@ class SparseMatrix:
-------- --------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]]) >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src) >>> A = dglsp.spmatrix(indices)
>>> A.csc() >>> A.csc()
(tensor([0, 0, 0, 1, 2, 3]), tensor([1, 1, 2]), tensor([0, 2, 1])) (tensor([0, 0, 0, 1, 2, 3]), tensor([1, 1, 2]), tensor([0, 2, 1]))
""" """
...@@ -521,7 +544,18 @@ def spmatrix( ...@@ -521,7 +544,18 @@ def spmatrix(
[3., 3.]]), [3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,)) shape=(3, 5), nnz=3, val_size=(2,))
""" """
return from_coo(indices[0], indices[1], val, shape) if shape is None:
shape = (
torch.max(indices[0]).item() + 1,
torch.max(indices[1]).item() + 1,
)
if val is None:
val = torch.ones(indices.shape[1]).to(indices.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(indices, val, shape))
def from_coo( def from_coo(
...@@ -599,16 +633,8 @@ def from_coo( ...@@ -599,16 +633,8 @@ def from_coo(
[3., 3.]]), [3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,)) shape=(3, 5), nnz=3, val_size=(2,))
""" """
if shape is None: assert row.shape[0] == col.shape[0]
shape = (torch.max(row).item() + 1, torch.max(col).item() + 1) return spmatrix(torch.stack([row, col]), val, shape)
if val is None:
val = torch.ones(row.shape[0]).to(row.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape))
def from_csr( def from_csr(
...@@ -833,6 +859,171 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix: ...@@ -833,6 +859,171 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val)) return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
    """Creates a sparse matrix from a torch sparse tensor, which can have coo,
    csr, or csc layout.

    Parameters
    ----------
    torch_sparse_tensor : torch.Tensor
        Torch sparse tensor

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> torch_coo = torch.sparse_coo_tensor(indices, val)
    >>> dglsp.from_torch_sparse(torch_coo)
    SparseMatrix(indices=tensor([[1, 1, 2],
                                 [2, 4, 3]]),
                 values=tensor([1., 1., 1.]),
                 shape=(3, 5), nnz=3)
    """
    layout = torch_sparse_tensor.layout
    assert layout in (
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_csc,
    ), (
        f"Cannot convert Pytorch sparse tensor with layout "
        f"{torch_sparse_tensor.layout} to DGL sparse."
    )
    # Only the first two dims are sparse; trailing dims (if any) belong to
    # the per-entry values.
    matrix_shape = torch_sparse_tensor.shape[:2]
    if layout == torch.sparse_coo:
        # ._indices()/._values() expose the uncoalesced data, avoiding a
        # potentially expensive coalesce and preserving duplicate entries.
        return spmatrix(
            torch_sparse_tensor._indices(),
            torch_sparse_tensor._values(),
            matrix_shape,
        )
    if layout == torch.sparse_csr:
        return from_csr(
            torch_sparse_tensor.crow_indices(),
            torch_sparse_tensor.col_indices(),
            torch_sparse_tensor.values(),
            matrix_shape,
        )
    return from_csc(
        torch_sparse_tensor.ccol_indices(),
        torch_sparse_tensor.row_indices(),
        torch_sparse_tensor.values(),
        matrix_shape,
    )
def to_torch_sparse_coo(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse coo tensor from a sparse matrix.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_coo layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_coo(spmat)
    tensor(indices=tensor([[1, 1, 2],
                           [2, 4, 3]]),
           values=tensor([1., 1., 1.]),
           size=(3, 5), nnz=3, layout=torch.sparse_coo)
    """
    val = spmat.val
    full_shape = spmat.shape
    # Vector-valued entries contribute trailing dense dims to the shape.
    if val.dim() > 1:
        full_shape += val.shape[1:]
    return torch.sparse_coo_tensor(spmat.indices(), val, full_shape)
def to_torch_sparse_csr(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csr tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csr tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csr layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csr(spmat)
    tensor(crow_indices=tensor([0, 0, 2, 3]),
           col_indices=tensor([2, 3, 4]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csr)
    """
    full_shape = spmat.shape
    # Vector-valued entries contribute trailing dense dims to the shape.
    if spmat.val.dim() > 1:
        full_shape += spmat.val.shape[1:]
    indptr, col_ids, value_indices = spmat.csr()
    # CSR conversion may reorder the nonzeros; value_indices (when present)
    # maps CSR order back to the stored value order.
    vals = spmat.val if value_indices is None else spmat.val[value_indices]
    return torch.sparse_csr_tensor(indptr, col_ids, vals, full_shape)
def to_torch_sparse_csc(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csc tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csc tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csc layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csc(spmat)
    tensor(ccol_indices=tensor([0, 0, 0, 1, 2, 3]),
           row_indices=tensor([1, 1, 2]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csc)
    """
    full_shape = spmat.shape
    # Vector-valued entries contribute trailing dense dims to the shape.
    if spmat.val.dim() > 1:
        full_shape += spmat.val.shape[1:]
    indptr, row_ids, value_indices = spmat.csc()
    # CSC conversion may reorder the nonzeros; value_indices (when present)
    # maps CSC order back to the stored value order.
    vals = spmat.val if value_indices is None else spmat.val[value_indices]
    return torch.sparse_csc_tensor(indptr, row_ids, vals, full_shape)
def _sparse_matrix_str(spmat: SparseMatrix) -> str: def _sparse_matrix_str(spmat: SparseMatrix) -> str:
"""Internal function for converting a sparse matrix to string """Internal function for converting a sparse matrix to string
representation. representation.
......
...@@ -5,7 +5,16 @@ import backend as F ...@@ -5,7 +5,16 @@ import backend as F
import pytest import pytest
import torch import torch
from dgl.sparse import from_coo, from_csc, from_csr, val_like from dgl.sparse import (
from_coo,
from_csc,
from_csr,
from_torch_sparse,
to_torch_sparse_coo,
to_torch_sparse_csc,
to_torch_sparse_csr,
val_like,
)
@pytest.mark.parametrize("dense_dim", [None, 4]) @pytest.mark.parametrize("dense_dim", [None, 4])
...@@ -502,3 +511,98 @@ def test_sparse_matrix_transpose(dense_dim, row, col, extra_shape): ...@@ -502,3 +511,98 @@ def test_sparse_matrix_transpose(dense_dim, row, col, extra_shape):
assert torch.allclose(mat_val, val) assert torch.allclose(mat_val, val)
assert torch.allclose(mat_row, col) assert torch.allclose(mat_row, col)
assert torch.allclose(mat_col, row) assert torch.allclose(mat_col, row)
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("nz_dim", [None, 2])
@pytest.mark.parametrize("shape", [(5, 5), (6, 7)])
def test_torch_sparse_coo_conversion(row, col, nz_dim, shape):
    # Round-trip: torch COO tensor -> SparseMatrix -> torch COO tensor,
    # verifying the conversions share storage (zero-copy) in both directions.
    dev = F.ctx()
    row = torch.tensor(row).to(dev)
    col = torch.tensor(col).to(dev)
    indices = torch.stack([row, col])
    torch_sparse_shape = shape
    val_shape = (row.shape[0],)
    if nz_dim is not None:
        # Vector-valued entries: the torch tensor gains a trailing dense dim.
        torch_sparse_shape += (nz_dim,)
        val_shape += (nz_dim,)
    val = torch.randn(val_shape).to(dev)
    torch_sparse_coo = torch.sparse_coo_tensor(indices, val, torch_sparse_shape)
    spmat = from_torch_sparse(torch_sparse_coo)

    def _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo):
        assert torch_sparse_coo.layout == torch.sparse_coo
        # Use .data_ptr() to check whether indices and values are on the same
        # memory address
        assert (
            spmat.indices().data_ptr() == torch_sparse_coo._indices().data_ptr()
        )
        assert spmat.val.data_ptr() == torch_sparse_coo._values().data_ptr()
        assert spmat.shape == torch_sparse_coo.shape[:2]

    _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo)
    torch_sparse_coo = to_torch_sparse_coo(spmat)
    _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo)
@pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("shape", [(3, 5), (3, 7)])
def test_torch_sparse_csr_conversion(indptr, indices, shape):
    # Round-trip: torch CSR tensor -> SparseMatrix -> torch CSR tensor,
    # verifying the conversions share storage (zero-copy) in both directions.
    dev = F.ctx()
    indptr = torch.tensor(indptr).to(dev)
    indices = torch.tensor(indices).to(dev)
    torch_sparse_shape = shape
    val_shape = (indices.shape[0],)
    val = torch.randn(val_shape).to(dev)
    torch_sparse_csr = torch.sparse_csr_tensor(
        indptr, indices, val, torch_sparse_shape
    )
    spmat = from_torch_sparse(torch_sparse_csr)

    def _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr):
        indptr, indices, value_indices = spmat.csr()
        assert torch_sparse_csr.layout == torch.sparse_csr
        # No reordering happened, so no value-index remapping is needed.
        assert value_indices is None
        # Use .data_ptr() to check whether indices and values are on the same
        # memory address
        assert indptr.data_ptr() == torch_sparse_csr.crow_indices().data_ptr()
        assert indices.data_ptr() == torch_sparse_csr.col_indices().data_ptr()
        assert spmat.val.data_ptr() == torch_sparse_csr.values().data_ptr()
        assert spmat.shape == torch_sparse_csr.shape[:2]

    _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr)
    torch_sparse_csr = to_torch_sparse_csr(spmat)
    _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr)
@pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("shape", [(8, 3), (5, 3)])
def test_torch_sparse_csc_conversion(indptr, indices, shape):
    # Round-trip: torch CSC tensor -> SparseMatrix -> torch CSC tensor,
    # verifying the conversions share storage (zero-copy) in both directions.
    dev = F.ctx()
    indptr = torch.tensor(indptr).to(dev)
    indices = torch.tensor(indices).to(dev)
    torch_sparse_shape = shape
    val_shape = (indices.shape[0],)
    val = torch.randn(val_shape).to(dev)
    torch_sparse_csc = torch.sparse_csc_tensor(
        indptr, indices, val, torch_sparse_shape
    )
    spmat = from_torch_sparse(torch_sparse_csc)

    def _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc):
        indptr, indices, value_indices = spmat.csc()
        assert torch_sparse_csc.layout == torch.sparse_csc
        # No reordering happened, so no value-index remapping is needed.
        assert value_indices is None
        # Use .data_ptr() to check whether indices and values are on the same
        # memory address
        assert indptr.data_ptr() == torch_sparse_csc.ccol_indices().data_ptr()
        assert indices.data_ptr() == torch_sparse_csc.row_indices().data_ptr()
        assert spmat.val.data_ptr() == torch_sparse_csc.values().data_ptr()
        assert spmat.shape == torch_sparse_csc.shape[:2]

    _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)
    torch_sparse_csc = to_torch_sparse_csc(spmat)
    _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment