Unverified commit c1dada48, authored by Riyasat Ohib, committed by GitHub

Implementation of dense_sst_to_dst and sst_dst_to_dense (#1048)

[Feat] Implements dense_sst_to_dst and sst_dst_to_dense methods and adds tests

1. Implements the dense_sst_to_dst and sst_dst_to_dense methods.
2. Adds tests for perfect reconstruction with all top-k across different dims.
3. Adds tests for the two new methods.
parent d3bda798
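
The new methods complete the round trip dense -> SST -> DST -> reconstructed tensor. A minimal usage sketch, not part of the commit itself; the constructor arguments and the 4x3, top-2 settings mirror the test fixtures added below:

    import torch
    from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

    dense = torch.arange(12).reshape(4, 3).float()
    sparser = SignalSparsity(sst_top_k_element=2, sst_top_k_dim=None, dst_top_k_element=2, dst_top_k_dim=None)
    sst = sparser.dense_to_sst(dense)           # keep the top-2 FFT components, zero out the rest
    dst = sparser.dense_sst_to_dst(dense, sst)  # keep the top-2 entries of the residual dense - ifft(sst)
    rt = sparser.sst_dst_to_dense(sst, dst)     # reconstruction: ifft(sst) + dst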
@@ -88,6 +88,19 @@ def _dct_transform(dense: Tensor) -> Tensor:
     raise NotImplementedError("Support for DCT has not been implemented yet!")


+def _inverse_dct_transform(sst: Tensor) -> Tensor:
+    """Should take a tensor and perform an inverse Discrete Cosine Transform and return a new tensor.
+
+    Args:
+        sst (Tensor):
+            Input sst tensor (may have zeros) in frequency domain.
+
+    Returns:
+        (Tensor):
+            A new, transformed dense tensor with real domain values.
+    """
+    raise NotImplementedError("Support for iDCT has not been implemented yet!")
+
+
 class Algo(Enum):
     FFT = 0
     DCT = 1
@@ -156,7 +169,7 @@ class SignalSparsity:
         self._validate_conf()
         # TODO (Min): Type checking for the following
-        self._transform = torch.fft.fft if algo is Algo.FFT else _dct_transform  # type: ignore
+        self._transform, self._inverse_transform = (torch.fft.fft, torch.fft.ifft) if algo is Algo.FFT else (_dct_transform, _inverse_dct_transform)  # type: ignore

     def _validate_conf(self) -> None:
         """Validating if the config is valid.
@@ -230,15 +243,13 @@ class SignalSparsity:
         # or DCT transformed components when using DCT (currently not implemented).
         # TODO: In case of the FFT, the imaginary part can perhaps be quantized or pruning can be
         # done on the smaller phases.
-        real_dense_freq = dense_freq.real.abs()
+        real_dense_freq = torch.real(dense_freq).abs()
         return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)

     def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
-        """From dense and SST to a DST
-
-        This will use sst_dst_to_dense below but with dst=None.
-
-        dense - ifft(sst) [using sst_dst_to_dense below] -> top-k -> result
+        """Calculates DST from input dense and SST tensors.
+
+        dense - inverse_transform(sst) [using the sst_dst_to_dense method] -> top-k -> dst

         Args:
             dense (Tensor):
@@ -250,32 +261,33 @@ class SignalSparsity:
             (Tensor):
                 Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
         """
-        pass
+        if not (dense.shape == sst.shape):
+            raise ValueError("dense and sst have different shapes!")
+
+        top_k_total_size = _top_k_total_size(dense, self._dst_top_k_dim)
+        k = _get_k_for_topk(self._dst_top_k_percent, self._dst_top_k_element, top_k_total_size)
+        delta = dense - self.sst_dst_to_dense(sst)  # sst_dst_to_dense(sst) returns the inverse transform here
+        del dense
+        return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim)

-    def sst_dst_to_dense(self, sst: Tensor, dst: Tensor = None) -> Tensor:
-        """From SST and dst back to a dense
-
-        result = ifft(sst)
-        if dst is not None:
-            result += dst
-        return result
+    def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
+        """From SST and DST returns a dense reconstructed tensor (RT). When the dst argument is None,
+        simply returns the inverse transform of the SST tensor.

         Args:
             sst (Tensor):
                 Signal sparse tensor. Required argument.
-            dst (Tensor, optinoal):
+            dst (Tensor, optional):
                 Delta sparse tensor, optional.

         Returns:
             (Tensor):
                 A dense tensor in real number domain from the SST.
         """
-        pass
-
-    def sst_or_dst_to_mask(self) -> None:
-        # we shouldn't need this function since going from SST/DST to mask should be a
-        # trivial call in pytorch. Maybe I am missing something.
-        pass
+        dense_rt = torch.real(self._inverse_transform(sst))
+        if dst is not None:
+            dense_rt += dst
+        return dense_rt

     # We could have separate helper functions that work on state_dict instead of a tensor.
...
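
Before the test changes, it helps to trace the dim=None, top-2 fixture by hand; this is a worked example derived from the expected tensors in the diff below, and the numbers can be checked against expd_sst_4x3_None, expd_dst_4x3_None and expd_rt_4x3_None:

    dense = arange(12).reshape(4, 3) = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
    FFT along the last dim gives rows such as fft([6, 7, 8]) = [21, -1.5+0.87j, -1.5-0.87j]
    and fft([9, 10, 11]) = [30, -1.5+0.87j, -1.5-0.87j]; the top-2 components by |real part|
    over the whole tensor are 21 and 30, which become the SST.
    ifft(sst) = [[0, 0, 0], [0, 0, 0], [7, 7, 7], [10, 10, 10]]
    delta = dense - ifft(sst); its top-2 entries by magnitude are 4 and 5, so
    dst = [[0, 0, 0], [0, 4, 5], [0, 0, 0], [0, 0, 0]] and
    rt = ifft(sst) + dst = [[0, 0, 0], [0, 4, 5], [7, 7, 7], [10, 10, 10]]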
@@ -12,16 +12,19 @@ from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

 def get_test_params():
     """Helper function to create and return a list of tuples of the form:

-    (in_tensor, expected_tensor, dim, percent, top_k_element) to be used as parameters for tests.
+    (dense, expected_sst, expected_dst, expected_reconstructed_tensor (RT), dim, percent, top_k_element)
+    to be used as parameters for tests.
     """

     # input in_tensors
-    tensor_4x3 = torch.arange(12).reshape(4, 3)
-    tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2)
+    tensor_4x3_None = torch.arange(12).reshape(4, 3).float()
+    tensor_4x3_0 = torch.arange(50, 62).reshape(4, 3) / 100
+    tensor_3x3_1 = torch.linspace(-5, 5, 9).reshape(3, 3)
+    tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2).float()

-    # Expected SST output tensors for 4x3 tensor of ascending ints
-    expected_4x3_None = torch.tensor(
+    # with dim=None, top-2
+    expd_sst_4x3_None = torch.tensor(
         [
-            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],  # with dim=None, top-2
+            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [21.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [30.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
@@ -29,40 +32,94 @@ def get_test_params():
         ],
         dtype=torch.complex64,
     )
-    expected_4x3_0 = torch.tensor(
+
+    # with dim=None, top-2
+    expd_dst_4x3_None = torch.tensor(
+        [[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32
+    )
+
+    # expected_reconstructed_tensor with dim=None and top-2 for both sst and dst
+    expd_rt_4x3_None = torch.tensor(
+        [[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [7.0, 7.0, 7.0], [10.0, 10.0, 10.0]], dtype=torch.float32
+    )
+
+    # with dim=0, top-2
+    expd_sst_4x3_0 = torch.tensor(
         [
-            [0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j],  # with dim=0, top-2
-            [0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j],
-            [21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
-            [30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
+            [0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
+            [0.0000000000 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
+            [1.7100000381 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
+            [1.7999999523 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
         ],
         dtype=torch.complex64,
     )
-    expected_4x3_1 = torch.tensor(
+
+    # with dim=0, top-2
+    expd_dst_4x3_0 = torch.tensor(
+        [[0.5000, 0.5100, 0.5200], [0.5400, 0.5400, 0.5400], [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]]
+    )
+
+    # expected_reconstructed_tensor with dim=0 and top-2 for both sst and dst
+    expd_rt_4x3_0 = torch.tensor(
+        [[0.5000, 0.5100, 0.5200], [0.5300, 0.5400, 0.5500], [0.5700, 0.5700, 0.5700], [0.5900, 0.6000, 0.6100]]
+    )
+
+    # with dim=1, top-2
+    expd_sst_3x3_1 = torch.tensor(
         [
-            [3.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],  # with dim=1, top-2
-            [12.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
-            [21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
-            [30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
+            [-11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
+            [0.0000000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, -1.8750000000 - 1.0825316906j],
+            [11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
         ],
         dtype=torch.complex64,
     )
-    expected_2x2x3_1 = torch.tensor(
+
+    # with dim=1, top-2
+    expd_dst_3x3_1 = torch.tensor(
+        [
+            [-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
+            [0.0000000000e00, -4.8244856998e-08, 0.0000000000e00],
+            [-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
+        ]
+    )
+
+    # expected_reconstructed_tensor with dim=1 and top-2 for both sst and dst
+    expd_rt_3x3_1 = torch.tensor([[-5.0000, -3.7500, -2.5000], [-1.2500, 0.0000, 1.2500], [2.5000, 3.7500, 5.0000]])
+
+    # with dim=1, top-1
+    expd_sst_2x2x3_1 = torch.tensor(
         [
-            [[1.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, -1.0 + 0.0j]],  # with dim=1, top-2
-            [[9.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, -1.0 + 0.0j]],
-            [[17.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, -1.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, 0.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, 0.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, 0.0 + 0.0j]],
         ],
         dtype=torch.complex64,
     )
+
+    # with dim=1, top-1
+    expd_dst_2x2x3_1 = torch.tensor(
+        [
+            [[0.5000, 0.5000], [0.0000, 0.0000]],
+            [[4.5000, 4.5000], [0.0000, 0.0000]],
+            [[8.5000, 8.5000], [0.0000, 0.0000]],
+        ],
+        dtype=torch.float32,
+    )
+
+    # expected_reconstructed_tensor with dim=1 and top-1 for both sst and dst
+    expd_rt_2x2x3_1 = torch.tensor(
+        [
+            [[0.0000, 1.0000], [2.5000, 2.5000]],
+            [[4.0000, 5.0000], [6.5000, 6.5000]],
+            [[8.0000, 9.0000], [10.5000, 10.5000]],
+        ],
+        dtype=torch.float32,
+    )
+
     return [
-        (tensor_4x3, expected_4x3_None, None, 20, 2),
-        (tensor_4x3, expected_4x3_0, 0, 50, 2),
-        (tensor_4x3, expected_4x3_1, 1, 70, 2),
-        (tensor_2x2x3, expected_2x2x3_1, 1, 100, 2),
+        (tensor_4x3_None, expd_sst_4x3_None, expd_dst_4x3_None, expd_rt_4x3_None, None, 20, 2),
+        (tensor_4x3_0, expd_sst_4x3_0, expd_dst_4x3_0, expd_rt_4x3_0, 0, 50, 2),
+        (tensor_3x3_1, expd_sst_3x3_1, expd_dst_3x3_1, expd_rt_3x3_1, 1, 70, 2),
+        (tensor_2x2x3, expd_sst_2x2x3_1, expd_dst_2x2x3_1, expd_rt_2x2x3_1, 1, 50, 1),
     ]
@@ -128,16 +185,16 @@ def test_dense_to_sst_perfect_recons(tensor, dim):
     assert all((sparser_2d.dense_to_sst(tensor) == torch.fft.fft(tensor)).flatten())


-@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
-def test_dense_to_sst_fixed(tensor, expected, dim, percent, k):
-    """Tests for fixed input dense tensor and fixed expected output SST tensor for top-2 elements."""
+@pytest.mark.parametrize("tensor, expd_sst, unused1, unused2, dim, unused3, k", get_test_params())
+def test_dense_to_sst_fixed(tensor, expd_sst, unused1, unused2, dim, unused3, k):
+    """Tests for fixed input dense tensor and fixed expected output SST tensor."""
     sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
     sst = sparser_2d.dense_to_sst(tensor)
-    objects_are_equal(sst, expected, raise_exception=True)
+    objects_are_equal(sst, expd_sst, raise_exception=True)


-@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
-def test_percent_element(tensor, expected, dim, percent, k):
+@pytest.mark.parametrize("tensor, unused1, unused2, unused3, dim, percent, k", get_test_params())
+def test_percent_element(tensor, unused1, unused2, unused3, dim, percent, k):
     """Tests whether comparative values for top_k_element and top_k_percent return the same outputs."""
     sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
     sst_element = sparser_2d.dense_to_sst(tensor)
@@ -147,3 +204,42 @@ def test_percent_element(tensor, unused1, unused2, unused3, dim, percent, k):
     )
     sst_percent = sparser_2d.dense_to_sst(tensor)
     objects_are_equal(sst_element, sst_percent, raise_exception=True)
@pytest.mark.parametrize("tensor, sst, expd_dst, unused1, dim, unused2, k", get_test_params())
def test_dense_sst_to_dst(tensor, sst, expd_dst, unused1, dim, unused2, k):
"""Tests fixed expected output DST tensor with fixed input dense and SST tensors."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, dst_top_k_element=k, dst_top_k_dim=dim)
dst = sparser_2d.dense_sst_to_dst(tensor, sst)
objects_are_equal(dst, expd_dst, raise_exception=True)
@pytest.mark.parametrize(
"dense, k, dim",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 40, None), # top-40, dim=None
(torch.linspace(0.1, 0.6, 30).reshape(5, 6), 5, 0), # top-5, dim=0
(torch.linspace(-0.1, 0.6, 35).reshape(7, 5), 5, 1), # top-5, dim=1
(torch.arange(60).float().reshape(10, 6), 60, None), # top-60, dim=None
(torch.arange(60).float().reshape(10, 6), 10, 0), # top-10, dim=0
(torch.arange(60).float().reshape(10, 6), 6, 1), # top-6, dim=1
(torch.arange(60).float().reshape(2, 5, 6), 5, 1), # top-5, dim=1
],
)
def test_sst_dst_to_perfect_dense_reconstruction(dense, k, dim):
"""Tests whether perfect reconstruction of input dense tensor is generated when top-k matches the numel
across some dimension dim for both SST and DST.
"""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
sst = sparser.dense_to_sst(dense)
dst = sparser.dense_sst_to_dst(dense, sst)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense, dense_recons, raise_exception=True)
@pytest.mark.parametrize("unused1, sst, dst, expd_rt, dim, unused2, k", get_test_params())
def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
"""Tests the correct expected reconstruction from frozen sst and dst tensors."""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)
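
The perfect-reconstruction cases hold by construction: when the top-k element count equals the tensor's size along the chosen dim, dense_to_sst keeps every frequency component and dense_sst_to_dst keeps the entire residual, so sst_dst_to_dense returns ifft(sst) + (dense - ifft(sst)), which is exactly the original dense tensor.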