Unverified commit 1a8d234d authored by Min Xu and committed by GitHub

[feat] add random_sparse_mask api (#1066)



* [feat] add random_sparse_mask api

* correct test skip
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 19033c32
@@ -21,7 +21,7 @@ except ImportError:
from .repo import Repo
-from .signal_sparsity import Algo, SignalSparsity
+from .signal_sparsity import Algo, SignalSparsity, random_sparse_mask
from .signal_sparsity_profiling import EnergyConcentrationProfile
from .version import __version_tuple__
@@ -426,3 +426,21 @@ class SignalSparsity:
        sst = self.dense_to_sst(dense)
        dst = self.dense_sst_to_dst(dense, sst)
        return self.sst_dst_to_dense(sst, dst), sst, dst


def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor:
    """Get a random sparse mask.

    Args:
        dense (Tensor):
            Input dense tensor (no zeros).
        percent (float):
            Percent of non-zeros to keep, given as a fraction, e.g. ``0.01`` keeps 1% of the
            entries along ``dim``.
        dim (int):
            Dimension along which the random sparse mask is computed.

    Returns:
        (Tensor):
            A mask tensor with the same shape as ``dense`` in which ``percent`` of the entries
            along ``dim`` are 1 and the rest are 0.
    """
    assert percent > 0, percent
    # Uniform random scores make the kept entries a uniformly random choice.
    rand = torch.rand_like(dense)
    ones = torch.ones_like(dense)
    # Keep the top-k random scores along `dim`, i.e. `percent` of that dimension.
    k = _get_k_for_topk(percent, None, dense.shape[dim])
    return _scatter_topk_to_sparse_tensor(rand, ones, k, dim)
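
For readers unfamiliar with the two private helpers used above, the construction boils down to: draw uniform random scores, keep the top-k scores along `dim`, and mark those positions with 1. A rough, self-contained sketch of that idea (illustrative only: `random_sparse_mask_sketch` is a made-up name, and the real `_get_k_for_topk` / `_scatter_topk_to_sparse_tensor` helpers may round and scatter differently):

import torch

def random_sparse_mask_sketch(dense: torch.Tensor, percent: float, dim: int) -> torch.Tensor:
    # Number of entries to keep per slice along `dim` (at least one).
    k = max(1, round(percent * dense.shape[dim]))
    # Uniform random scores decide which entries survive.
    scores = torch.rand_like(dense)
    # Positions of the k largest scores along `dim` become the 1s of the mask.
    idx = scores.topk(k, dim=dim).indices
    mask = torch.zeros_like(dense)
    mask.scatter_(dim, idx, 1.0)
    return mask
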
@@ -7,7 +7,7 @@ import pytest
import torch
from fair_dev.testing.testing import objects_are_equal
-from fairscale.experimental.wgit.signal_sparsity import SignalSparsity
+from fairscale.experimental.wgit.signal_sparsity import SignalSparsity, random_sparse_mask
# Our own tolerance
ATOL = 1e-6
@@ -427,3 +427,19 @@ def test_dst_disabled():
    objects_are_equal(rt, result_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
    objects_are_equal(sst, result_sst, raise_exception=True, rtol=RTOL, atol=ATOL)
    assert dst is None


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_random_sparse_mask(device):
    """Tests random_sparse_mask API."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no GPU")
    # 20% of a 5-element tensor keeps exactly one non-zero.
    dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device)
    mask = random_sparse_mask(dense, 0.2, 0)
    assert mask.sum() == 1
    for d in [0, 1]:
        # 1% of 100 entries along dim `d` keeps one non-zero per slice, 100 in total.
        dense = torch.rand(100, 100).to(device)
        mask = random_sparse_mask(dense, 0.01, d)
        assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True)
        assert mask.sum() == 100
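
Putting it together, typical use of the new API to randomly sparsify a tensor could look like the sketch below (the shape and the 25% keep rate are illustrative, not taken from the commit):

import torch

from fairscale.experimental.wgit.signal_sparsity import random_sparse_mask

# Keep 25% of the entries along dim 1, chosen uniformly at random.
weights = torch.randn(4, 8)
mask = random_sparse_mask(weights, percent=0.25, dim=1)
sparse_weights = weights * mask  # entries not selected by the mask become zero

print(mask.sum(dim=1))  # each row keeps 0.25 * 8 = 2 non-zeros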