Unverified Commit 5c60f33c authored by Riyasat Ohib, committed by GitHub

implementation of lossy_compression method (#1051)

* [Feat] implements lossy_compress with tests

1. Implements a lossy_compress method that takes a dense tensor and returns a lossy reconstruction along with the SST and DST tensors used to build it; when the requested sparsity is zero, the dense tensor itself is returned as the reconstruction. A brief usage sketch follows below.
parent c1dada48
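A minimal usage sketch of the new API (the top-k settings and the import path are illustrative assumptions; the exact module path is not shown in this diff):

```python
import torch
# Assumed import for illustration; adjust to the module this diff touches.
# from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

# Illustrative settings: keep the top 50% of SST and DST values along dim 0.
sparser = SignalSparsity(
    sst_top_k_percent=50, sst_top_k_dim=0,
    dst_top_k_percent=50, dst_top_k_dim=0,
)
dense = torch.linspace(0.01, 0.06, 40).reshape(5, 8)
lossy_dense, sst, dst = sparser.lossy_compress(dense)
# lossy_dense approximates dense; sst and dst have the same shape as dense.
```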
...@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
from dataclasses import dataclass
from enum import Enum
import json
...@@ -276,6 +277,7 @@ class Repo:
return element
state_dict = torch.load(file_path)
ret_state_dict = copy.deepcopy(state_dict) # This is only a temporary addition for testing.
_recursive_apply_to_elements(state_dict, fn, [])
file_path_or_state_dict = state_dict
...
...@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Optional, Tuple
import torch
from torch import Tensor
...@@ -75,6 +75,17 @@ def _top_k_total_size(tensor: Tensor, topk_dim: Optional[int]) -> int:
return top_k_total_size
def _is_sparsity_zero(
dense: Tensor, topk_percent: Optional[float], topk_element: Optional[int], top_k_dim: Optional[int]
) -> bool:
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
"""
top_k_total_size = _top_k_total_size(dense, top_k_dim)
k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size)
return k == top_k_total_size
def _dct_transform(dense: Tensor) -> Tensor:
"""Should take a tensor and perform a Discrete Cosine Transform on the tensor.
...@@ -289,6 +300,35 @@ class SignalSparsity:
dense_rt += dst
return dense_rt
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor]:
"""Computes a lossy reconstruction of the input dense tensor from its SST and DST
tensors. If the requested sparsity is zero (i.e. top-100-percent), simply returns
the input dense tensor as the reconstruction.
Args:
dense (Tensor):
Input dense tensor (no zeros).
Returns:
(Tuple[Tensor, Optional[Tensor], Tensor]):
A tuple of the form (lossy_reconstruction, sst, dst), each with the same shape as the
dense tensor; sst is None when the requested sparsity is zero.
"""
if _is_sparsity_zero(
dense, self._sst_top_k_percent, self._sst_top_k_element, self._sst_top_k_dim
) and _is_sparsity_zero(dense, self._dst_top_k_percent, self._dst_top_k_element, self._dst_top_k_dim):
# when sparsity is 0% for both sst and dst, the dense tensor itself is returned as the reconstructed
# tensor, sst is returned as None and dst as the dense tensor. This choice is made because with the
# returned sst=None and dst=dense, we should be able to recombine them if needed to retrieve the
# dense tensor again as: dense = inv_transform(sst) + dst, where inv_transform(sst=None) = zero_tensor
# of the same size as dense.
return dense, None, dense
else:
sst = self.dense_to_sst(dense)
dst = self.dense_sst_to_dst(dense, sst)
return self.sst_dst_to_dense(sst, dst), sst, dst
# We could have separate helper functions that work on state_dict instead of a tensor.
# One option is to extend the above class to handle state_dict as well as tensor
...
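For reference, a short sketch of the recombination described in the comment above; the settings are illustrative and the import is assumed as in the earlier sketch. When sparsity is non-zero, the returned reconstruction is exactly sst_dst_to_dense(sst, dst):

```python
import torch

# Illustrative settings; assumes SignalSparsity is importable as sketched above.
sparser = SignalSparsity(sst_top_k_percent=50, sst_top_k_dim=0, dst_top_k_percent=50, dst_top_k_dim=0)
dense = torch.linspace(-10, 15, 36).reshape(6, 6)
lossy_dense, sst, dst = sparser.lossy_compress(dense)
recombined = sparser.sst_dst_to_dense(sst, dst)
# The recombination is deterministic, so it matches the returned reconstruction.
assert torch.allclose(lossy_dense, recombined)
```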
...@@ -243,3 +243,32 @@ def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)
@pytest.mark.parametrize("tensor, expd_sst, expd_dst, expd_rt, dim, unused, k", get_test_params())
def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k):
"""Tests the lossy_compress method against expected sst, dst and reconstruced tensor."""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, expd_rt, raise_exception=True)
objects_are_equal(sst, expd_sst, raise_exception=True)
objects_are_equal(dst, expd_dst, raise_exception=True)
@pytest.mark.parametrize(
"tensor, dim, top_k_percent",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 0, 100),
(torch.linspace(-0.01, 0.06, 42).reshape(7, 6), 0, 100),
(torch.linspace(-10, 15, 36).reshape(6, 6), 1, 100),
],
)
def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent):
"""Tests whether lossy_compress method simply returns dense tensor when sparsity is 0."""
sparser = SignalSparsity(
sst_top_k_percent=top_k_percent, sst_top_k_dim=dim, dst_top_k_percent=top_k_percent, dst_top_k_dim=dim
)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, tensor, raise_exception=True)
objects_are_equal(sst, None, raise_exception=True)
objects_are_equal(dst, tensor, raise_exception=True)