"src/vscode:/vscode.git/clone" did not exist on "7c05b975b79df39875959494020e4b5eedd2c4c8"
Unverified Commit a23b490d authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Sparse sample implementation (#6303)

parent 0440806a
...@@ -182,6 +182,39 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -182,6 +182,39 @@ class SparseMatrix : public torch::CustomClassHolder {
c10::intrusive_ptr<SparseMatrix> RangeSelect( c10::intrusive_ptr<SparseMatrix> RangeSelect(
int64_t dim, int64_t start, int64_t end); int64_t dim, int64_t start, int64_t end);
/**
 * @brief Create a SparseMatrix by sampling elements based on the specified
 * dimension and sample count.
 *
 * If `ids` is provided, this function samples elements from the specified
 * set of row or column IDs, resulting in a sparse matrix containing only
 * the sampled rows or columns.
 *
 * @param dim Select rows (dim=0) or columns (dim=1) for sampling.
 * @param fanout The number of elements to randomly sample from each row or
 * column.
 * @param ids A tensor containing the row or column IDs from which to
 * sample elements.
 * @param replace Indicates whether repeated sampling of the same element
 * is allowed. If True, repeated sampling is allowed; otherwise, it is not
 * allowed.
 * @param bias A boolean flag indicating whether to enable biasing
 * during sampling. If True, the values of the sparse matrix will be used as
 * bias weights, meaning that elements with higher values will be more likely
 * to be sampled. Otherwise, all elements will be sampled uniformly,
 * regardless of their value.
 *
 * @return A new SparseMatrix containing the sampled elements. Its extent
 * along `dim` equals the number of entries in `ids` (row i / column i of
 * the result corresponds to ids[i]); the other dimension keeps the
 * original extent. When `ids` covers every row or column, the shape
 * matches the original matrix.
 *
 * @note If 'replace = false' and a row/column has fewer elements than
 * 'fanout', all of its non-zero elements will be sampled.
 * @note If 'ids' is not provided (Python API), the function samples from
 * all rows or columns.
 */
c10::intrusive_ptr<SparseMatrix> Sample(
    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias);
/** /**
* @brief Create a SparseMatrix from a SparseMatrix using new values. * @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix * @param mat An existing sparse matrix
......
...@@ -35,7 +35,8 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -35,7 +35,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("has_duplicate", &SparseMatrix::HasDuplicate) .def("has_duplicate", &SparseMatrix::HasDuplicate)
.def("is_diag", &SparseMatrix::HasDiag) .def("is_diag", &SparseMatrix::HasDiag)
.def("index_select", &SparseMatrix::IndexSelect) .def("index_select", &SparseMatrix::IndexSelect)
.def("range_select", &SparseMatrix::RangeSelect); .def("range_select", &SparseMatrix::RangeSelect)
.def("sample", &SparseMatrix::Sample);
m.def("from_coo", &SparseMatrix::FromCOO) m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR) .def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC) .def("from_csc", &SparseMatrix::FromCSC)
......
...@@ -167,6 +167,34 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect( ...@@ -167,6 +167,34 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
} }
} }
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Sample(
    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias) {
  // dim == 0 samples rows via the CSR; dim == 1 samples columns by treating
  // the CSC as the CSR of the transposed matrix, then transposing back.
  const bool sample_rows = (dim == 0);
  auto dgl_ids = TorchTensorToDGLArray(ids);
  auto compressed = sample_rows ? this->CSRPtr() : this->CSCPtr();

  // Restrict the matrix to the requested rows/columns before sampling.
  auto sliced = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(compressed), dgl_ids);
  auto sliced_values =
      this->value().index_select(0, DGLArrayToTorchTensor(sliced.data));
  // Values are materialized above; drop the value-index indirection.
  sliced.data = dgl::aten::NullArray();

  // With bias enabled the matrix values act as sampling weights (larger
  // values are more likely to be picked); a null array means uniform.
  auto weights =
      bias ? TorchTensorToDGLArray(sliced_values) : dgl::aten::NullArray();
  // Every row of the sliced matrix gets sampled.
  auto all_slice_rows =
      dgl::aten::Range(0, dgl_ids.NumElements(), 64, dgl_ids->ctx);
  auto sampled = dgl::aten::CSRRowWiseSampling(
      sliced, all_slice_rows, fanout, weights, replace);
  auto sampled_values =
      sliced_values.index_select(0, DGLArrayToTorchTensor(sampled.data));
  sampled.data = dgl::aten::NullArray();

  auto result = COOFromOldDGLCOO(sampled);
  // Column-wise sampling operated on the transpose; undo it here.
  if (!sample_rows) result = COOTranspose(result);
  return SparseMatrix::FromCOOPointer(
      result, sampled_values, {result->num_rows, result->num_cols});
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike( c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) { const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK( TORCH_CHECK(
......
...@@ -671,7 +671,14 @@ class SparseMatrix: ...@@ -671,7 +671,14 @@ class SparseMatrix:
values=tensor([0, 2, 3, 3]), values=tensor([0, 2, 3, 3]),
shape=(3, 2), nnz=3) shape=(3, 2), nnz=3)
""" """
raise NotImplementedError if ids is None:
dim_size = self.shape[0] if dim == 0 else self.shape[1]
ids = torch.range(
0, dim_size, dtype=torch.int64, device=self.device
)
return SparseMatrix(
self.c_sparse_matrix.sample(dim, fanout, ids, replace, bias)
)
def spmatrix( def spmatrix(
......
...@@ -504,6 +504,90 @@ def test_range_select(create_func, shape, dense_dim, select_dim, rang): ...@@ -504,6 +504,90 @@ def test_range_select(create_func, shape, dense_dim, select_dim, rang):
assert torch.allclose(A_select_to_dense, dense_select) assert torch.allclose(A_select_to_dense, dense_select)
@pytest.mark.parametrize(
    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("index", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_sample_rowwise(create_func, index, replace, bias):
    """Row-wise sampling must pick elements only from the requested rows and
    respect the fanout / replacement semantics."""
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 0
    sample_num = 3
    mat = create_func(shape, 10, ctx)
    # Make all values non-negative so they are usable as sampling weights.
    mat = val_like(mat, torch.abs(mat.val))
    row_ids = torch.tensor(index).to(ctx)
    sampled = mat.sample(sample_dim, sample_num, row_ids, replace, bias)
    dense = sparse_matrix_to_dense(mat)
    sampled_dense = sparse_matrix_to_dense(sampled)
    expected_shape = (row_ids.size(0), shape[1])
    # Every sampled element must originate from its source row.
    for out_row, src_row in enumerate(list(row_ids)):
        src_cols = list(dense[src_row, :].nonzero().reshape(-1))
        out_cols = list(sampled_dense[out_row, :].nonzero().reshape(-1))
        for col in out_cols:
            assert col in src_cols
        if replace:
            # With replacement, a non-empty row yields exactly 'sample_num'
            # entries; an empty row yields none.
            expected_count = sample_num if len(src_cols) != 0 else 0
            assert (
                list(sampled.row).count(torch.tensor(out_row))
                == expected_count
            )
        else:
            assert len(out_cols) == min(sample_num, len(src_cols))
    assert sampled.shape == expected_shape
    if not replace:
        assert not sampled.has_duplicate()
@pytest.mark.parametrize(
    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("index", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_sample_columnwise(create_func, index, replace, bias):
    """Column-wise sampling must pick elements only from the requested columns
    and respect the fanout / replacement semantics."""
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 1
    sample_num = 3
    mat = create_func(shape, 10, ctx)
    # Make all values non-negative so they are usable as sampling weights.
    mat = val_like(mat, torch.abs(mat.val))
    col_ids = torch.tensor(index).to(ctx)
    sampled = mat.sample(sample_dim, sample_num, col_ids, replace, bias)
    dense = sparse_matrix_to_dense(mat)
    sampled_dense = sparse_matrix_to_dense(sampled)
    expected_shape = (shape[0], col_ids.size(0))
    # Every sampled element must originate from its source column.
    for out_col, src_col in enumerate(list(col_ids)):
        src_rows = list(dense[:, src_col].nonzero().reshape(-1))
        out_rows = list(sampled_dense[:, out_col].nonzero().reshape(-1))
        for row in out_rows:
            assert row in src_rows
        if replace:
            # With replacement, a non-empty column yields exactly
            # 'sample_num' entries; an empty column yields none.
            expected_count = sample_num if len(src_rows) != 0 else 0
            assert (
                list(sampled.col).count(torch.tensor(out_col))
                == expected_count
            )
        else:
            assert len(out_rows) == min(sample_num, len(src_rows))
    assert sampled.shape == expected_shape
    if not replace:
        assert not sampled.has_duplicate()
def test_print(): def test_print():
ctx = F.ctx() ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment