Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b4c351b4
Unverified
Commit
b4c351b4
authored
Sep 07, 2023
by
xiangyuzhi
Committed by
GitHub
Sep 07, 2023
Browse files
[Sparse] Add sparse sample python API (#6287)
parent
fa3f2f48
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
0 deletions
+87
-0
python/dgl/sparse/sparse_matrix.py
python/dgl/sparse/sparse_matrix.py
+87
-0
No files found.
python/dgl/sparse/sparse_matrix.py
View file @
b4c351b4
...
@@ -586,6 +586,93 @@ class SparseMatrix:
...
@@ -586,6 +586,93 @@ class SparseMatrix:
)
)
raise
TypeError
(
f
"
{
type
(
index
).
__name__
}
is unsupported input type."
)
raise
TypeError
(
f
"
{
type
(
index
).
__name__
}
is unsupported input type."
)
def
sample
(
self
,
dim
:
int
,
fanout
:
int
,
ids
:
Optional
[
torch
.
Tensor
]
=
None
,
replace
:
Optional
[
bool
]
=
False
,
bias
:
Optional
[
bool
]
=
False
,
):
"""Returns a sampled matrix on the given dimension and sample arguments.
Parameters
----------
dim : int
The dimension for sampling, should be 0 or 1. `dim = 0` for
rowwise selection and `dim = 1` for columnwise selection.
fanout : int
The number of elements to randomly sample on each row or column.
ids : torch.Tensor, optional
An optional tensor containing row or column IDs from which to
sample elements.
NOTE: If `ids` is not provided (i.e., `ids = None`), the function
will sample from all rows or columns.
replace : bool, optional
Indicates whether repeated sampling of the same element is allowed.
When `replace = True`, repeated sampling is permitted; when
`replace = False`, it is not allowed.
NOTE: If `replace = False` and there are fewer elements than
`fanout`, all non-zero elements will be sampled.
bias : bool, optional
A boolean flag indicating whether to enable biasing during sampling.
When `bias = True`, the values of the sparse matrix will be used as
bias weights.
The function does not support autograd.
Returns
-------
SparseMatrix
A submatrix with the same shape as the original matrix, containing
the randomly sampled non-zero elements.
Examples
--------
>>> indices = torch.tensor([[0, 0, 1, 1, 2, 2, 2],
[0, 2, 0, 1, 0, 1, 2]])
>>> val = torch.tensor([0, 1, 2, 3, 4, 5, 6])
>>> A = dglsp.spmatrix(indices, val)
Case 1: Sample rows with the given number and disable repeated sampling.
>>> row_ids = torch.tensor([0, 2])
>>> A.sample(0, 2, row_ids)
SparseMatrix(indices=tensor([[0, 0, 1, 1],
[0, 2, 0, 2]]),
values=tensor([0, 1, 4, 6]),
shape=(2, 3), nnz=4)
Case 2: Sample cols with the given number and disable repeated sampling.
>>> col_ids = torch.tensor([0, 2])
>>> A.sample(1, 2, col_ids)
SparseMatrix(indices=tensor([[0, 1, 0, 2],
[0, 0, 1, 1]]),
values=tensor([0, 2, 1, 6]),
shape=(3, 2), nnz=4)
Case 3: Sample rows with the given number and enable repeated sampling.
>>> row_ids = torch.tensor([0, 1])
>>> A.sample(0, 2, row_ids, True)
SparseMatrix(indices=tensor([[0, 0, 1, 1],
[0, 2, 0, 0]]),
values=tensor([0, 1, 2, 2]),
shape=(2, 3), nnz=3)
Case 4: Sample cols with the given number and enable repeated sampling.
>>> col_ids = torch.tensor([0, 1])
>>> A.sample(1, 2, col_ids, True)
SparseMatrix(indices=tensor([[0, 1, 1, 1],
[0, 0, 1, 1]]),
values=tensor([0, 2, 3, 3]),
shape=(3, 2), nnz=3)
"""
raise
NotImplementedError
def
spmatrix
(
def
spmatrix
(
indices
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment