Unverified Commit 454537d1 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[minor] fix doc and assert and test around percent (#1067)

parent 1a8d234d
...@@ -435,11 +435,11 @@ def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor: ...@@ -435,11 +435,11 @@ def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor:
dense (Tensor): dense (Tensor):
Input dense tensor (no zeros). Input dense tensor (no zeros).
percent (float): percent (float):
Percent of non-zeros. Percent of non-zeros (0, 100].
dim (int): dim (int):
Dimension on which the random sparse mask is computed. Dimension on which the random sparse mask is computed.
""" """
assert percent > 0, percent assert percent > 0 and percent <= 100, percent
rand = torch.rand_like(dense) rand = torch.rand_like(dense)
ones = torch.ones_like(dense) ones = torch.ones_like(dense)
k = _get_k_for_topk(percent, None, dense.shape[dim]) k = _get_k_for_topk(percent, None, dense.shape[dim])
......
...@@ -436,10 +436,10 @@ def test_random_sparse_mask(device): ...@@ -436,10 +436,10 @@ def test_random_sparse_mask(device):
pytest.skip("no GPU") pytest.skip("no GPU")
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device) dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device)
mask = random_sparse_mask(dense, 0.2, 0) mask = random_sparse_mask(dense, 20, 0)
assert mask.sum() == 1 assert mask.sum() == 1
for d in [0, 1]: for d in [0, 1]:
dense = torch.rand(100, 100).to(device) dense = torch.rand(100, 100).to(device)
mask = random_sparse_mask(dense, 0.01, d) mask = random_sparse_mask(dense, 1, d)
assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True) assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True)
assert mask.sum() == 100 assert mask.sum() == 100
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment