Unverified Commit 5c60f33c authored by Riyasat Ohib's avatar Riyasat Ohib Committed by GitHub
Browse files

implementation of the lossy_compress method (#1051)

* [Feat] implements lossy_compress with tests

1. Implements a method lossy_compress that takes in a dense tensor and returns a reconstruction with sst and dst, and optionally with sparsity.
parent c1dada48
......@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
from dataclasses import dataclass
from enum import Enum
import json
......@@ -276,6 +277,7 @@ class Repo:
return element
state_dict = torch.load(file_path)
ret_state_dict = copy.deepcopy(state_dict) # This is only a temporary addition for testing.
_recursive_apply_to_elements(state_dict, fn, [])
file_path_or_state_dict = state_dict
......
......@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Optional
from typing import Optional, Tuple
import torch
from torch import Tensor
......@@ -75,6 +75,17 @@ def _top_k_total_size(tensor: Tensor, topk_dim: Optional[int]) -> int:
return top_k_total_size
def _is_sparsity_zero(
    dense: Tensor, topk_percent: Optional[float], topk_element: Optional[int], top_k_dim: Optional[int]
) -> bool:
    """Check whether the requested top-k setting keeps every element.

    Returns True when the given ``topk_percent`` or ``topk_element`` along
    ``top_k_dim`` selects the full top-k size of ``dense`` (i.e. sparsity is
    0%, equivalently top-100-percent); otherwise returns False.
    """
    total = _top_k_total_size(dense, top_k_dim)
    return _get_k_for_topk(topk_percent, topk_element, total) == total
def _dct_transform(dense: Tensor) -> Tensor:
"""Should take a tensor and perform a Discrete Cosine Transform on the tensor.
......@@ -289,6 +300,35 @@ class SignalSparsity:
dense_rt += dst
return dense_rt
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor]:
    """Lossily compress a dense tensor via its SST and DST decomposition.

    Computes the signal-sparse tensor (SST) and delta-sparse tensor (DST)
    for ``dense`` and returns the reconstruction obtained by recombining
    them. If the requested sparsity is zero for both SST and DST
    (top-100-percent), the input dense tensor itself is returned as the
    reconstruction.

    Args:
        dense (Tensor):
            Input dense tensor (no zeros).

    Returns:
        (Tuple[Tensor, Optional[Tensor], Tensor]):
            A tuple ``(lossy_reconstruction, sst, dst)`` where the
            reconstruction and ``dst`` have the same shape as ``dense``.
            ``sst`` is ``None`` in the zero-sparsity case described below,
            hence the ``Optional`` in the return type.
    """
    if _is_sparsity_zero(
        dense, self._sst_top_k_percent, self._sst_top_k_element, self._sst_top_k_dim
    ) and _is_sparsity_zero(dense, self._dst_top_k_percent, self._dst_top_k_element, self._dst_top_k_dim):
        # When sparsity is 0% for both sst and dst, the dense tensor itself is
        # returned as the reconstructed tensor, sst is returned as None and dst
        # as the dense tensor. This choice is made because with the returned
        # sst=None and dst=dense, we should be able to recombine them if needed
        # to retrieve the dense tensor again as:
        # dense = inv_transform(sst) + dst, where inv_transform(sst=None) is a
        # zero tensor of the same size as dense.
        return dense, None, dense
    else:
        sst = self.dense_to_sst(dense)
        dst = self.dense_sst_to_dst(dense, sst)
        return self.sst_dst_to_dense(sst, dst), sst, dst
# We could separate have helper functions that work on state_dict instead of a tensor.
# One option is to extend the above class to handle state_dict as well as tensor
......
......@@ -243,3 +243,32 @@ def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)
@pytest.mark.parametrize("tensor, expd_sst, expd_dst, expd_rt, dim, unused, k", get_test_params())
def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k):
    """Checks the lossy_compress output against the expected sst, dst and reconstructed tensor."""
    sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
    recon, sst_out, dst_out = sparser.lossy_compress(tensor)
    # Validate all three returned tensors against the parametrized expectations.
    for actual, expected in ((recon, expd_rt), (sst_out, expd_sst), (dst_out, expd_dst)):
        objects_are_equal(actual, expected, raise_exception=True)
@pytest.mark.parametrize(
    "tensor, dim, top_k_percent",
    [
        (torch.linspace(0.01, 0.06, 40).reshape(5, 8), 0, 100),
        (torch.linspace(-0.01, 0.06, 42).reshape(7, 6), 0, 100),
        (torch.linspace(-10, 15, 36).reshape(6, 6), 1, 100),
    ],
)
def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent):
    """Verifies that lossy_compress passes the dense tensor through untouched when sparsity is 0."""
    sparser = SignalSparsity(
        sst_top_k_percent=top_k_percent, sst_top_k_dim=dim, dst_top_k_percent=top_k_percent, dst_top_k_dim=dim
    )
    recon, sst_out, dst_out = sparser.lossy_compress(tensor)
    # Expected pass-through contract: reconstruction == dense, sst is None, dst == dense.
    objects_are_equal(recon, tensor, raise_exception=True)
    objects_are_equal(sst_out, None, raise_exception=True)
    objects_are_equal(dst_out, tensor, raise_exception=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment