Unverified Commit 454537d1 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[minor] fix doc and assert and test around percent (#1067)

parent 1a8d234d
......@@ -435,11 +435,11 @@ def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor:
dense (Tensor):
Input dense tensor (no zeros).
percent (float):
Percent of non-zeros.
Percent of non-zeros (0, 100].
dim (int):
Dimension on which the random sparse mask is computed.
"""
assert percent > 0, percent
assert percent > 0 and percent <= 100, percent
rand = torch.rand_like(dense)
ones = torch.ones_like(dense)
k = _get_k_for_topk(percent, None, dense.shape[dim])
......
......@@ -436,10 +436,10 @@ def test_random_sparse_mask(device):
pytest.skip("no GPU")
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device)
mask = random_sparse_mask(dense, 0.2, 0)
mask = random_sparse_mask(dense, 20, 0)
assert mask.sum() == 1
for d in [0, 1]:
dense = torch.rand(100, 100).to(device)
mask = random_sparse_mask(dense, 0.01, d)
mask = random_sparse_mask(dense, 1, d)
assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True)
assert mask.sum() == 100
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment