"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2e833520618dc460cbeb693e29e40b65a02ccafb"
Unverified Commit 60b02f2f authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add SparseMatrix transposition (#4940)

parent d29b312c
...@@ -81,6 +81,10 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo); ...@@ -81,6 +81,10 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);
/** @brief Convert a CSR format to CSC format. */ /** @brief Convert a CSR format to CSC format. */
std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr); std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);
/** @brief COO transposition. */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);
} // namespace sparse } // namespace sparse
} // namespace dgl } // namespace dgl
......
...@@ -117,6 +117,11 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -117,6 +117,11 @@ class SparseMatrix : public torch::CustomClassHolder {
*/ */
void SetValue(torch::Tensor value); void SetValue(torch::Tensor value);
/** @brief Return the transposition of the sparse matrix. It transposes the
* first existing sparse format by checking COO, CSR, and CSC.
*/
c10::intrusive_ptr<SparseMatrix> Transpose() const;
private: private:
/** @brief Create the COO format for the sparse matrix internally */ /** @brief Create the COO format for the sparse matrix internally */
void _CreateCOO(); void _CreateCOO();
......
...@@ -24,7 +24,8 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -24,7 +24,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("coo", &SparseMatrix::COOTensors) .def("coo", &SparseMatrix::COOTensors)
.def("csr", &SparseMatrix::CSRTensors) .def("csr", &SparseMatrix::CSRTensors)
.def("csc", &SparseMatrix::CSCTensors) .def("csc", &SparseMatrix::CSCTensors)
.def("set_val", &SparseMatrix::SetValue); .def("set_val", &SparseMatrix::SetValue)
.def("transpose", &SparseMatrix::Transpose);
m.def("create_from_coo", &CreateFromCOO) m.def("create_from_coo", &CreateFromCOO)
.def("create_from_csr", &CreateFromCSR) .def("create_from_csr", &CreateFromCSR)
.def("create_from_csc", &CreateFromCSC) .def("create_from_csc", &CreateFromCSC)
......
...@@ -86,5 +86,11 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) { ...@@ -86,5 +86,11 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
return CSRFromOldDGLCSR(dgl_csc); return CSRFromOldDGLCSR(dgl_csc);
} }
/** @brief Transpose a COO matrix.
 *
 * Implemented by round-tripping through the legacy DGL COO representation:
 * convert, run the existing aten transpose kernel, convert back.
 */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
  return COOFromOldDGLCOO(aten::COOTranspose(COOToOldDGLCOO(coo)));
}
} // namespace sparse } // namespace sparse
} // namespace dgl } // namespace dgl
...@@ -114,6 +114,20 @@ SparseMatrix::CSCTensors() { ...@@ -114,6 +114,20 @@ SparseMatrix::CSCTensors() {
void SparseMatrix::SetValue(torch::Tensor value) { value_ = value; } void SparseMatrix::SetValue(torch::Tensor value) { value_ = value; }
/** @brief Return the transposition of the sparse matrix.
 *
 * Reuses whichever sparse format already exists, checking COO, then CSR,
 * then CSC. A COO is transposed explicitly; a CSR of this matrix is by
 * definition a CSC of its transpose (and vice versa), so those two only
 * need the shape reversed.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  auto transposed_shape = shape_;
  std::swap(transposed_shape[0], transposed_shape[1]);
  if (HasCOO()) {
    // Swap row/col indices of the COO representation.
    return SparseMatrix::FromCOO(COOTranspose(coo_), value_, transposed_shape);
  }
  if (HasCSR()) {
    // Reinterpret the CSR of this matrix as the CSC of the transpose.
    return SparseMatrix::FromCSC(csr_, value_, transposed_shape);
  }
  // Likewise, the CSC of this matrix is the CSR of the transpose.
  return SparseMatrix::FromCSR(csc_, value_, transposed_shape);
}
void SparseMatrix::_CreateCOO() { void SparseMatrix::_CreateCOO() {
if (HasCOO()) return; if (HasCOO()) return;
if (HasCSR()) { if (HasCSR()) {
......
...@@ -149,6 +149,39 @@ class SparseMatrix: ...@@ -149,6 +149,39 @@ class SparseMatrix:
mat[row, col] = val mat[row, col] = val
return mat return mat
def t(self):
    """Shorthand alias of :meth:`transpose`."""
    return self.transpose()
@property
def T(self):  # pylint: disable=C0103
    """Property alias of :meth:`transpose`, mirroring ``torch.Tensor.T``."""
    return self.transpose()
def transpose(self):
    """Return the transpose of this sparse matrix.

    Returns
    -------
    SparseMatrix
        The transpose of this sparse matrix.

    Example
    -------

    >>> row = torch.tensor([1, 1, 3])
    >>> col = torch.tensor([2, 1, 3])
    >>> val = torch.tensor([1, 1, 2])
    >>> A = create_from_coo(row, col, val)
    >>> A = A.transpose()
    >>> print(A)
    SparseMatrix(indices=tensor([[2, 1, 3],
            [1, 1, 3]]),
    values=tensor([1, 1, 2]),
    shape=(4, 4), nnz=3)
    """
    # Delegate to the C++ implementation and re-wrap the result.
    transposed = self.c_sparse_matrix.transpose()
    return SparseMatrix(transposed)
def create_from_coo( def create_from_coo(
row: torch.Tensor, row: torch.Tensor,
......
...@@ -4,7 +4,7 @@ import sys ...@@ -4,7 +4,7 @@ import sys
from dgl.mock_sparse2 import diag, identity, DiagMatrix from dgl.mock_sparse2 import diag, identity, DiagMatrix
# FIXME: Skipping tests on win. # FIXME(issue #4818): Skipping tests on win.
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True) pytest.skip("skipping tests on win", allow_module_level=True)
......
...@@ -4,9 +4,9 @@ import numpy as np ...@@ -4,9 +4,9 @@ import numpy as np
import pytest import pytest
import torch import torch
import sys import sys
from dgl.mock_sparse import diag from dgl.mock_sparse2 import diag
# FIXME: Skipping tests on win. # FIXME(issue #4818): Skipping tests on win.
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True) pytest.skip("skipping tests on win", allow_module_level=True)
......
...@@ -7,7 +7,7 @@ import sys ...@@ -7,7 +7,7 @@ import sys
import dgl import dgl
from dgl.mock_sparse2 import create_from_coo, diag from dgl.mock_sparse2 import create_from_coo, diag
# FIXME: Skipping tests on win. # FIXME(issue #4818): Skipping tests on win.
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True) pytest.skip("skipping tests on win", allow_module_level=True)
......
...@@ -6,7 +6,7 @@ import backend as F ...@@ -6,7 +6,7 @@ import backend as F
from dgl.mock_sparse2 import create_from_coo, create_from_csr, create_from_csc from dgl.mock_sparse2 import create_from_coo, create_from_csr, create_from_csc
# FIXME: Skipping tests on win. # FIXME(issue #4818): Skipping tests on win.
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
pytest.skip("skipping tests on win", allow_module_level=True) pytest.skip("skipping tests on win", allow_module_level=True)
......
import pytest
import torch
import sys
from dgl.mock_sparse2 import diag, create_from_coo
import backend as F
# FIXME(issue #4818): Skipping tests on win.
# The sparse library is only built/supported on Linux for now, so skip the
# whole module on any other platform.
if not sys.platform.startswith("linux"):
    pytest.skip("skipping tests on win", allow_module_level=True)
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
    """Transposing a diagonal matrix keeps its values and reverses its shape."""
    device = F.ctx()
    values = torch.randn(val_shape).to(device)
    transposed = diag(values, mat_shape).transpose()
    # Diagonal values are unchanged by transposition.
    assert torch.allclose(transposed.val, values)
    # With no explicit shape, diag() builds a square matrix.
    expected = mat_shape if mat_shape is not None else (val_shape[0], val_shape[0])
    assert transposed.shape == expected[::-1]
@pytest.mark.parametrize("dense_dim", [None, 2])
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("extra_shape", [(0, 1), (2, 1)])
def test_sparse_matrix_transpose(dense_dim, row, col, extra_shape):
    """COO transpose swaps row/col indices, reverses shape, keeps values."""
    device = F.ctx()
    shape = (max(row) + 1 + extra_shape[0], max(col) + 1 + extra_shape[1])
    # Values may carry an extra dense dimension per nonzero.
    value_shape = (len(row),) if dense_dim is None else (len(row), dense_dim)
    values = torch.randn(value_shape).to(device)
    rows = torch.tensor(row).to(device)
    cols = torch.tensor(col).to(device)
    transposed = create_from_coo(rows, cols, values, shape).transpose()
    t_rows, t_cols = transposed.coo()
    assert transposed.shape == shape[::-1]
    assert torch.allclose(transposed.val, values)
    # Row/col indices swap roles under transposition.
    assert torch.allclose(t_rows, cols)
    assert torch.allclose(t_cols, rows)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment