Unverified Commit e088acac authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Sparse] Add val_like and Disable Setting Nonzero Values (#4972)



* Update

* Update
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent b1e2695f
...@@ -111,12 +111,6 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -111,12 +111,6 @@ class SparseMatrix : public torch::CustomClassHolder {
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
CSCTensors(); CSCTensors();
/**
* @brief Set non-zero values of the sparse matrix
* @param value Values of the sparse matrix
*/
void SetValue(torch::Tensor value);
/** @brief Return the transposition of the sparse matrix. It transposes the /** @brief Return the transposition of the sparse matrix. It transposes the
* first existing sparse format by checking COO, CSR, and CSC. * first existing sparse format by checking COO, CSR, and CSC.
*/ */
...@@ -178,6 +172,16 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSC( ...@@ -178,6 +172,16 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor value, torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/**
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix
* @param value New values of the sparse matrix
*
* @return SparseMatrix
*/
c10::intrusive_ptr<SparseMatrix> CreateValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);
} // namespace sparse } // namespace sparse
} // namespace dgl } // namespace dgl
......
...@@ -24,12 +24,12 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -24,12 +24,12 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("coo", &SparseMatrix::COOTensors) .def("coo", &SparseMatrix::COOTensors)
.def("csr", &SparseMatrix::CSRTensors) .def("csr", &SparseMatrix::CSRTensors)
.def("csc", &SparseMatrix::CSCTensors) .def("csc", &SparseMatrix::CSCTensors)
.def("set_val", &SparseMatrix::SetValue)
.def("transpose", &SparseMatrix::Transpose); .def("transpose", &SparseMatrix::Transpose);
m.def("create_from_coo", &CreateFromCOO) m.def("create_from_coo", &CreateFromCOO)
.def("create_from_csr", &CreateFromCSR) .def("create_from_csr", &CreateFromCSR)
.def("create_from_csc", &CreateFromCSC) .def("create_from_csc", &CreateFromCSC)
.def("spsp_add", &SpSpAdd); .def("spsp_add", &SpSpAdd)
.def("val_like", &CreateValLike);
} }
} // namespace sparse } // namespace sparse
......
...@@ -112,8 +112,6 @@ SparseMatrix::CSCTensors() { ...@@ -112,8 +112,6 @@ SparseMatrix::CSCTensors() {
return {csc->indptr, csc->indices, csc->value_indices}; return {csc->indptr, csc->indices, csc->value_indices};
} }
void SparseMatrix::SetValue(torch::Tensor value) { value_ = value; }
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const { c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
auto shape = shape_; auto shape = shape_;
std::swap(shape[0], shape[1]); std::swap(shape[0], shape[1]);
...@@ -187,5 +185,22 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSC( ...@@ -187,5 +185,22 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
return SparseMatrix::FromCSC(csc, value, shape); return SparseMatrix::FromCSC(csc, value, shape);
} }
c10::intrusive_ptr<SparseMatrix> CreateValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
CHECK_EQ(mat->value().size(0), value.size(0))
<< "The first dimension of the old values and the new values must be the "
"same.";
CHECK_EQ(mat->value().device(), value.device())
<< "The device of the old values and the new values must be the same.";
auto shape = mat->shape();
if (mat->HasCOO()) {
return SparseMatrix::FromCOO(mat->COOPtr(), value, shape);
} else if (mat->HasCSR()) {
return SparseMatrix::FromCSR(mat->CSRPtr(), value, shape);
} else {
return SparseMatrix::FromCSC(mat->CSCPtr(), value, shape);
}
}
} // namespace sparse } // namespace sparse
} // namespace dgl } // namespace dgl
...@@ -36,8 +36,30 @@ class DiagMatrix: ...@@ -36,8 +36,30 @@ class DiagMatrix:
) )
else: else:
shape = (len_val, len_val) shape = (len_val, len_val)
self.val = val self._val = val
self.shape = shape self._shape = shape
@property
def val(self) -> torch.Tensor:
"""Get the values of the nonzero elements.
Returns
-------
torch.Tensor
Values of the nonzero elements
"""
return self._val
@property
def shape(self) -> Tuple[int]:
"""Shape of the sparse matrix.
Returns
-------
Tuple[int]
The shape of the matrix
"""
return self._shape
def __repr__(self): def __repr__(self):
return f"DiagMatrix(val={self.val}, \nshape={self.shape})" return f"DiagMatrix(val={self.val}, \nshape={self.shape})"
......
...@@ -21,17 +21,6 @@ class SparseMatrix: ...@@ -21,17 +21,6 @@ class SparseMatrix:
""" """
return self.c_sparse_matrix.val() return self.c_sparse_matrix.val()
@val.setter
def val(self, x: torch.Tensor):
"""Set the non-zero values inplace.
Parameters
----------
x : torch.Tensor, optional
The values of shape (nnz) or (nnz, D)
"""
self.c_sparse_matrix.set_val(x)
@property @property
def shape(self) -> Tuple[int]: def shape(self) -> Tuple[int]:
"""Shape of the sparse matrix. """Shape of the sparse matrix.
...@@ -95,7 +84,7 @@ class SparseMatrix: ...@@ -95,7 +84,7 @@ class SparseMatrix:
Indices of the nonzero elements Indices of the nonzero elements
""" """
if fmt == "COO" and not return_shuffle: if fmt == "COO" and not return_shuffle:
row, col, _ = self.coo() row, col = self.coo()
return torch.stack([row, col]) return torch.stack([row, col])
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -120,7 +109,8 @@ class SparseMatrix: ...@@ -120,7 +109,8 @@ class SparseMatrix:
Returns Returns
------- -------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of tensors containing row, column coordinates and value indices. A tuple of tensors containing row, column coordinates and value
indices.
""" """
return self.c_sparse_matrix.csr() return self.c_sparse_matrix.csr()
...@@ -130,7 +120,8 @@ class SparseMatrix: ...@@ -130,7 +120,8 @@ class SparseMatrix:
Returns Returns
------- -------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of tensors containing row, column coordinates and value indices. A tuple of tensors containing row, column coordinates and value
indices.
""" """
return self.c_sparse_matrix.csc() return self.c_sparse_matrix.csc()
...@@ -419,3 +410,38 @@ def create_from_csc( ...@@ -419,3 +410,38 @@ def create_from_csc(
return SparseMatrix( return SparseMatrix(
torch.ops.dgl_sparse.create_from_csc(indptr, indices, val, shape) torch.ops.dgl_sparse.create_from_csc(indptr, indices, val, shape)
) )
def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
"""Create a sparse matrix from an existing sparse matrix using new values.
The new sparse matrix will have the same nonzero indices as the given
sparse matrix and use the given values as the new nonzero values.
Parameters
----------
mat : SparseMatrix
An existing sparse matrix with nnz nonzero values
val : tensor
The new nonzero values, a tensor of shape (nnz) or (nnz, D)
Returns
-------
SparseMatrix
New sparse matrix
Examples
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([2, 4, 3])
>>> val = torch.ones(3)
>>> A = create_from_coo(row, col, val)
>>> B = val_like(A, torch.tensor([2, 2, 2]))
>>> print(B)
SparseMatrix(indices=tensor([[1, 1, 2],
[2, 4, 3]]),
values=tensor([2, 2, 2]),
shape=(3, 5), nnz=3)
"""
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
...@@ -4,7 +4,7 @@ import sys ...@@ -4,7 +4,7 @@ import sys
import backend as F import backend as F
from dgl.mock_sparse2 import create_from_coo, create_from_csr, create_from_csc from dgl.mock_sparse2 import create_from_coo, create_from_csr, create_from_csc, val_like
# FIXME(issue #4818): Skipping tests on win. # FIXME(issue #4818): Skipping tests on win.
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
...@@ -111,20 +111,6 @@ def test_dense(val_shape): ...@@ -111,20 +111,6 @@ def test_dense(val_shape):
assert torch.allclose(A_dense, mat) assert torch.allclose(A_dense, mat)
def test_set_val():
ctx = F.ctx()
row = torch.tensor([1, 1, 2]).to(ctx)
col = torch.tensor([2, 4, 3]).to(ctx)
nnz = len(row)
old_val = torch.ones(nnz).to(ctx)
A = create_from_coo(row, col, old_val)
new_val = torch.zeros(nnz).to(ctx)
A.val = new_val
assert torch.allclose(new_val, A.val)
@pytest.mark.parametrize("dense_dim", [None, 4]) @pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)]) @pytest.mark.parametrize("indptr", [(0, 0, 1, 4), (0, 1, 2, 4)])
@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 4, 3, 2)]) @pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 4, 3, 2)])
...@@ -350,3 +336,35 @@ def test_csr_to_csc(dense_dim, indptr, indices, shape): ...@@ -350,3 +336,35 @@ def test_csr_to_csc(dense_dim, indptr, indices, shape):
assert torch.allclose(mat_val, val) assert torch.allclose(mat_val, val)
assert torch.allclose(mat_indptr, indptr) assert torch.allclose(mat_indptr, indptr)
assert torch.allclose(mat_indices, indices) assert torch.allclose(mat_indices, indices)
@pytest.mark.parametrize("val_shape", [(3), (3, 2)])
@pytest.mark.parametrize("shape", [(3, 5), (5, 5)])
def test_val_like(val_shape, shape):
def check_val_like(A, B):
assert A.shape == B.shape
assert A.nnz == B.nnz
assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))
assert A.val.device == B.val.device
ctx = F.ctx()
# COO
row = torch.tensor([1, 1, 2]).to(ctx)
col = torch.tensor([2, 4, 3]).to(ctx)
val = torch.randn(3).to(ctx)
coo_A = create_from_coo(row, col, val, shape)
new_val = torch.randn(val_shape).to(ctx)
coo_B = val_like(coo_A, new_val)
check_val_like(coo_A, coo_B)
# CSR
indptr, indices, _ = coo_A.csr()
csr_A = create_from_csr(indptr, indices, val, shape)
csr_B = val_like(csr_A, new_val)
check_val_like(csr_A, csr_B)
# CSC
indptr, indices, _ = coo_A.csc()
csc_A = create_from_csc(indptr, indices, val, shape)
csc_B = val_like(csc_A, new_val)
check_val_like(csc_A, csc_B)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment