"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bc8fd864ebdf867709586bccef38225360c2a6b5"
Unverified Commit dadce86a authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Add sparse matrix slicing operator implementation (#6208)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 729924e3
...@@ -137,6 +137,51 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -137,6 +137,51 @@ class SparseMatrix : public torch::CustomClassHolder {
static c10::intrusive_ptr<SparseMatrix> FromDiag( static c10::intrusive_ptr<SparseMatrix> FromDiag(
torch::Tensor value, const std::vector<int64_t>& shape); torch::Tensor value, const std::vector<int64_t>& shape);
/**
 * @brief Create a SparseMatrix by selecting rows or columns based on provided
 * indices.
 *
 * Builds a new SparseMatrix that keeps only the rows (dim=0) or the columns
 * (dim=1) of this matrix that are listed in 'ids'. This matrix itself is not
 * the return value; the result is a separate SparseMatrix object.
 *
 * @param dim Select rows (dim=0) or columns (dim=1).
 * @param ids A tensor containing the indices of the selected rows or columns.
 * The Python wrapper documents that duplicated ids are allowed.
 *
 * @return A new SparseMatrix containing the selected rows or columns.
 *
 * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
 * (for column-wise selection).
 * @note The 'ids' tensor should contain valid indices within the range of the
 * original SparseMatrix's dimensions.
 */
c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);
/**
 * @brief Create a SparseMatrix by selecting a contiguous range of rows or
 * columns.
 *
 * Builds a new SparseMatrix that keeps only the rows (dim=0) or the columns
 * (dim=1) of this matrix whose index lies in the half-open interval
 * [start, end). The result is a separate SparseMatrix object.
 *
 * @param dim Select rows (dim=0) or columns (dim=1).
 * @param start The starting index (inclusive) of the range.
 * @param end The ending index (exclusive) of the range.
 *
 * @return A new SparseMatrix containing the selected range of rows or
 * columns.
 *
 * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
 * (for column-wise selection).
 * @note The 'start' and 'end' indices should be valid indices within
 * the valid range of the original SparseMatrix's dimensions.
 */
c10::intrusive_ptr<SparseMatrix> RangeSelect(
int64_t dim, int64_t start, int64_t end);
/** /**
* @brief Create a SparseMatrix from a SparseMatrix using new values. * @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix * @param mat An existing sparse matrix
......
...@@ -33,7 +33,9 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -33,7 +33,9 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("transpose", &SparseMatrix::Transpose) .def("transpose", &SparseMatrix::Transpose)
.def("coalesce", &SparseMatrix::Coalesce) .def("coalesce", &SparseMatrix::Coalesce)
.def("has_duplicate", &SparseMatrix::HasDuplicate) .def("has_duplicate", &SparseMatrix::HasDuplicate)
.def("is_diag", &SparseMatrix::HasDiag); .def("is_diag", &SparseMatrix::HasDiag)
.def("index_select", &SparseMatrix::IndexSelect)
.def("range_select", &SparseMatrix::RangeSelect);
m.def("from_coo", &SparseMatrix::FromCOO) m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR) .def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC) .def("from_csc", &SparseMatrix::FromCSC)
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <sparse/sparse_matrix.h> #include <sparse/sparse_matrix.h>
#include <torch/script.h> #include <torch/script.h>
#include "./utils.h"
namespace dgl { namespace dgl {
namespace sparse { namespace sparse {
...@@ -122,6 +124,49 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag( ...@@ -122,6 +124,49 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
return SparseMatrix::FromDiagPointer(diag, value, shape); return SparseMatrix::FromDiagPointer(diag, value, shape);
} }
/**
 * @brief Build a new SparseMatrix from the rows or columns listed in `ids`.
 *
 * Row-wise selection (dim == 0) slices the CSR representation. Column-wise
 * selection (dim == 1) feeds the CSC representation through the same
 * row-slicing routine and swaps the two dimensions back when building the
 * shape of the result.
 *
 * @param dim Selection dimension: 0 for rows, 1 for columns. Any other value
 * is rejected with an error.
 * @param ids Tensor holding the IDs of the rows or columns to keep.
 *
 * @return A new SparseMatrix created from the sliced CSR (dim == 0) or CSC
 * (dim == 1) structure and the correspondingly gathered values.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
    int64_t dim, torch::Tensor ids) {
  // This method is also exposed directly as the "index_select" custom-class
  // op, so it cannot rely on a caller having validated `dim`. Without this
  // check every value other than 0 would silently select columns.
  TORCH_CHECK(
      dim == 0 || dim == 1,
      "The selection dimension should be 0 or 1, but got ", dim, ".");
  const bool rowwise = dim == 0;
  auto id_array = TorchTensorToDGLArray(ids);
  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
  // slice_csr.data is used as the index into the original value tensor, so
  // the gathered values line up with the entries of the sliced structure.
  auto slice_value =
      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  slice_csr.data = dgl::aten::NullArray();
  auto ret = CSRFromOldDGLCSR(slice_csr);
  if (rowwise) {
    return SparseMatrix::FromCSRPointer(
        ret, slice_value, {ret->num_rows, ret->num_cols});
  }
  // For the CSC structure the two dimensions are stored swapped, so they are
  // swapped back to obtain the (rows, columns) shape of the result.
  return SparseMatrix::FromCSCPointer(
      ret, slice_value, {ret->num_cols, ret->num_rows});
}
/**
 * @brief Build a new SparseMatrix from a contiguous range of rows or columns.
 *
 * Row-wise selection (dim == 0) slices the CSR representation. Column-wise
 * selection (dim == 1) feeds the CSC representation through the same
 * row-slicing routine and swaps the two dimensions back when building the
 * shape of the result.
 *
 * @param dim Selection dimension: 0 for rows, 1 for columns. Any other value
 * is rejected with an error.
 * @param start First index of the range (inclusive).
 * @param end Index one past the end of the range (exclusive).
 *
 * @return A new SparseMatrix created from the sliced CSR (dim == 0) or CSC
 * (dim == 1) structure and the correspondingly gathered values.
 */
c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
    int64_t dim, int64_t start, int64_t end) {
  // This method is also exposed directly as the "range_select" custom-class
  // op, so it cannot rely on a caller having validated `dim`. Without this
  // check every value other than 0 would silently select columns.
  TORCH_CHECK(
      dim == 0 || dim == 1,
      "The selection dimension should be 0 or 1, but got ", dim, ".");
  const bool rowwise = dim == 0;
  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), start, end);
  // slice_csr.data is used as the index into the original value tensor, so
  // the gathered values line up with the entries of the sliced structure.
  auto slice_value =
      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
  // To prevent potential errors in future conversions to the COO format,
  // where this array might be used as an initialization array for
  // constructing COO representations, it is necessary to clear this array.
  slice_csr.data = dgl::aten::NullArray();
  auto ret = CSRFromOldDGLCSR(slice_csr);
  if (rowwise) {
    return SparseMatrix::FromCSRPointer(
        ret, slice_value, {ret->num_rows, ret->num_cols});
  }
  // For the CSC structure the two dimensions are stored swapped, so they are
  // swapped back to obtain the (rows, columns) shape of the result.
  return SparseMatrix::FromCSCPointer(
      ret, slice_value, {ret->num_cols, ret->num_rows});
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike( c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) { const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK( TORCH_CHECK(
......
...@@ -487,7 +487,7 @@ class SparseMatrix: ...@@ -487,7 +487,7 @@ class SparseMatrix:
dim : int dim : int
The dim to select from matrix, should be 0 or 1. `dim = 0` for The dim to select from matrix, should be 0 or 1. `dim = 0` for
rowwise selection and `dim = 1` for columnwise selection. rowwise selection and `dim = 1` for columnwise selection.
index : tensor.Tensor index : torch.Tensor
The selection index indicates which IDs from the `dim` should The selection index indicates which IDs from the `dim` should
be chosen from the matrix. be chosen from the matrix.
Note that duplicated ids are allowed. Note that duplicated ids are allowed.
...@@ -527,7 +527,7 @@ class SparseMatrix: ...@@ -527,7 +527,7 @@ class SparseMatrix:
if dim not in (0, 1): if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.") raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, torch.Tensor): if isinstance(index, torch.Tensor):
raise NotImplementedError return SparseMatrix(self.c_sparse_matrix.index_select(dim, index))
raise TypeError(f"{type(index).__name__} is unsupported input type.") raise TypeError(f"{type(index).__name__} is unsupported input type.")
def range_select(self, dim: int, index: slice): def range_select(self, dim: int, index: slice):
...@@ -575,7 +575,15 @@ class SparseMatrix: ...@@ -575,7 +575,15 @@ class SparseMatrix:
if dim not in (0, 1): if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.") raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, slice): if isinstance(index, slice):
raise NotImplementedError if index.step not in (None, 1):
raise NotImplementedError(
"Slice with step other than 1 are not supported yet."
)
start = 0 if index.start is None else index.start
end = index.stop
return SparseMatrix(
self.c_sparse_matrix.range_select(dim, start, end)
)
raise TypeError(f"{type(index).__name__} is unsupported input type.") raise TypeError(f"{type(index).__name__} is unsupported input type.")
......
...@@ -18,6 +18,14 @@ from dgl.sparse import ( ...@@ -18,6 +18,14 @@ from dgl.sparse import (
val_like, val_like,
) )
from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)
def _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape): def _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape):
with warnings.catch_warnings(): with warnings.catch_warnings():
...@@ -450,6 +458,52 @@ def test_has_duplicate(): ...@@ -450,6 +458,52 @@ def test_has_duplicate():
assert csc_A.has_duplicate() assert csc_A.has_duplicate()
@pytest.mark.parametrize(
    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("index", [(0, 1, 3), (1, 2)])
def test_index_select(create_func, shape, dense_dim, select_dim, index):
    """Check sparse ``index_select`` against ``torch.index_select`` on the
    dense form of the same matrix, for every creation function, shape,
    value dimension, selection dimension and index set."""
    device = F.ctx()
    sparse_mat = create_func(shape, 20, device, dense_dim)
    ids = torch.tensor(index).to(device)

    # Select on the sparse matrix first, then build the dense reference.
    selected = sparse_mat.index_select(select_dim, ids)
    expected = torch.index_select(
        sparse_matrix_to_dense(sparse_mat), select_dim, ids
    )

    actual = sparse_matrix_to_dense(selected)
    assert actual.shape == expected.shape
    assert torch.allclose(actual, expected)
@pytest.mark.parametrize(
    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("rang", [slice(0, 2), slice(1, 3)])
def test_range_select(create_func, shape, dense_dim, select_dim, rang):
    """Check sparse ``range_select`` against plain slicing of the dense form
    of the same matrix, for every creation function, shape, value dimension,
    selection dimension and slice."""
    device = F.ctx()
    sparse_mat = create_func(shape, 20, device, dense_dim)

    # Select on the sparse matrix first, then build the dense reference.
    selected = sparse_mat.range_select(select_dim, rang)
    dense = sparse_matrix_to_dense(sparse_mat)
    expected = dense[:, rang] if select_dim != 0 else dense[rang, :]

    actual = sparse_matrix_to_dense(selected)
    assert actual.shape == expected.shape
    assert torch.allclose(actual, expected)
def test_print(): def test_print():
ctx = F.ctx() ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment