Unverified Commit b4c351b4 authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Add sparse sample python API (#6287)

parent fa3f2f48
...@@ -586,6 +586,93 @@ class SparseMatrix: ...@@ -586,6 +586,93 @@ class SparseMatrix:
) )
raise TypeError(f"{type(index).__name__} is unsupported input type.") raise TypeError(f"{type(index).__name__} is unsupported input type.")
def sample(
self,
dim: int,
fanout: int,
ids: Optional[torch.Tensor] = None,
replace: Optional[bool] = False,
bias: Optional[bool] = False,
):
"""Returns a sampled matrix on the given dimension and sample arguments.
Parameters
----------
dim : int
The dimension for sampling, should be 0 or 1. `dim = 0` for
rowwise selection and `dim = 1` for columnwise selection.
fanout : int
The number of elements to randomly sample on each row or column.
ids : torch.Tensor, optional
An optional tensor containing row or column IDs from which to
sample elements.
NOTE: If `ids` is not provided (i.e., `ids = None`), the function
will sample from all rows or columns.
replace : bool, optional
Indicates whether repeated sampling of the same element is allowed.
When `replace = True`, repeated sampling is permitted; when
`replace = False`, it is not allowed.
NOTE: If `replace = False` and there are fewer elements than
`fanout`, all non-zero elements will be sampled.
bias : bool, optional
A boolean flag indicating whether to enable biasing during sampling.
When `bias = True`, the values of the sparse matrix will be used as
bias weights.
The function does not support autograd.
Returns
-------
SparseMatrix
A submatrix with the same shape as the original matrix, containing
the randomly sampled non-zero elements.
Examples
--------
>>> indices = torch.tensor([[0, 0, 1, 1, 2, 2, 2],
[0, 2, 0, 1, 0, 1, 2]])
>>> val = torch.tensor([0, 1, 2, 3, 4, 5, 6])
>>> A = dglsp.spmatrix(indices, val)
Case 1: Sample rows with the given number and disable repeated sampling.
>>> row_ids = torch.tensor([0, 2])
>>> A.sample(0, 2, row_ids)
SparseMatrix(indices=tensor([[0, 0, 1, 1],
[0, 2, 0, 2]]),
values=tensor([0, 1, 4, 6]),
shape=(2, 3), nnz=4)
Case 2: Sample cols with the given number and disable repeated sampling.
>>> col_ids = torch.tensor([0, 2])
>>> A.sample(1, 2, col_ids)
SparseMatrix(indices=tensor([[0, 1, 0, 2],
[0, 0, 1, 1]]),
values=tensor([0, 2, 1, 6]),
shape=(3, 2), nnz=4)
Case 3: Sample rows with the given number and enable repeated sampling.
>>> row_ids = torch.tensor([0, 1])
>>> A.sample(0, 2, row_ids, True)
SparseMatrix(indices=tensor([[0, 0, 1, 1],
[0, 2, 0, 0]]),
values=tensor([0, 1, 2, 2]),
shape=(2, 3), nnz=3)
Case 4: Sample cols with the given number and enable repeated sampling.
>>> col_ids = torch.tensor([0, 1])
>>> A.sample(1, 2, col_ids, True)
SparseMatrix(indices=tensor([[0, 1, 1, 1],
[0, 0, 1, 1]]),
values=tensor([0, 2, 3, 3]),
shape=(3, 2), nnz=3)
"""
raise NotImplementedError
def spmatrix( def spmatrix(
indices: torch.Tensor, indices: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment