Unverified Commit 3cc7fa8d authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] support optional SST and DST (#1063)



* [feat] support sst disabled and dst disabled cases

* added tests
Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent 15d4cf15
...@@ -81,6 +81,9 @@ def _is_sparsity_zero( ...@@ -81,6 +81,9 @@ def _is_sparsity_zero(
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim """Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False. for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
""" """
if topk_percent is None and topk_element is None:
return False # 100% sparse
top_k_total_size = _top_k_total_size(dense, top_k_dim) top_k_total_size = _top_k_total_size(dense, top_k_dim)
k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size) k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size)
return k == top_k_total_size return k == top_k_total_size
...@@ -245,11 +248,20 @@ class SignalSparsity: ...@@ -245,11 +248,20 @@ class SignalSparsity:
self._dst_top_k_percent = dst_top_k_percent self._dst_top_k_percent = dst_top_k_percent
self._validate_conf() self._validate_conf()
# TODO (Min): Type checking for the following
self._transform, self._inverse_transform = ( self._transform, self._inverse_transform = (
(_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform) (_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform)
) )
@property
def _sst_enabled(self) -> bool:
"""True if SST is enabled."""
return self._sst_top_k_element is not None or self._sst_top_k_percent is not None
@property
def _dst_enabled(self) -> bool:
"""True if DST is enabled."""
return self._dst_top_k_element is not None or self._dst_top_k_percent is not None
def _validate_conf(self) -> None: def _validate_conf(self) -> None:
"""Validating if the config is valid. """Validating if the config is valid.
...@@ -262,16 +274,14 @@ class SignalSparsity: ...@@ -262,16 +274,14 @@ class SignalSparsity:
If validation fails. If validation fails.
""" """
# assert that both top_k_elements and top_k_percent aren't set for sst and dst # assert that both top_k_elements and top_k_percent aren't set for sst and dst
def one_and_only(a: Optional[int], b: Optional[float]) -> bool: def both_set(a: Optional[int], b: Optional[float]) -> bool:
return (a is None) ^ (b is None) return (a is not None) and (b is not None)
if not ( if both_set(self._sst_top_k_element, self._sst_top_k_percent) or both_set(
one_and_only(self._sst_top_k_element, self._sst_top_k_percent) self._dst_top_k_element, self._dst_top_k_percent
and one_and_only(self._dst_top_k_element, self._dst_top_k_percent)
): ):
raise ValueError( raise ValueError(
"One and only one of top_k_element and top_k_percent for " "top_k_element and top_k_percent can't be both set\n"
"each of sst and dst must be provided as an argument.\n"
f"Input values are: sst element={self._sst_top_k_element}, sst percent={self._sst_top_k_percent}, " f"Input values are: sst element={self._sst_top_k_element}, sst percent={self._sst_top_k_percent}, "
f"dst element={self._dst_top_k_element}, dst percent={self._dst_top_k_percent}" f"dst element={self._dst_top_k_element}, dst percent={self._dst_top_k_percent}"
) )
...@@ -296,7 +306,7 @@ class SignalSparsity: ...@@ -296,7 +306,7 @@ class SignalSparsity:
f"and dst element={self._dst_top_k_element}" f"and dst element={self._dst_top_k_element}"
) )
def dense_to_sst(self, dense: Tensor) -> Tensor: def dense_to_sst(self, dense: Tensor) -> Optional[Tensor]:
"""Get Signal Sparse Tensor (SST) from a dense tensor """Get Signal Sparse Tensor (SST) from a dense tensor
Dense -> fft -> top-k -> results. Dense -> fft -> top-k -> results.
...@@ -310,10 +320,14 @@ class SignalSparsity: ...@@ -310,10 +320,14 @@ class SignalSparsity:
Input dense tensor (no zeros). Input dense tensor (no zeros).
Returns: Returns:
(Tensor): (Tensor, optional):
Same shaped tensor as the input dense tensor, still in dense format but in frequency Same shaped tensor as the input dense tensor, still in dense format but in frequency
domain (complex valued) and has zeros. domain (complex valued) and has zeros.
""" """
if not self._sst_enabled:
# Special case, SST is simply None, which represents an all-zero tensor.
return None
top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim) top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size) k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
dense_freq = self._transform(dense, dim=self._sst_top_k_dim) dense_freq = self._transform(dense, dim=self._sst_top_k_dim)
...@@ -325,7 +339,7 @@ class SignalSparsity: ...@@ -325,7 +339,7 @@ class SignalSparsity:
real_dense_freq = dense_freq.real.abs() real_dense_freq = dense_freq.real.abs()
return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim) return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)
def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor: def dense_sst_to_dst(self, dense: Tensor, sst: Optional[Tensor]) -> Optional[Tensor]:
"""Calculates DST from input dense and SST tensors. """Calculates DST from input dense and SST tensors.
dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst
...@@ -340,6 +354,13 @@ class SignalSparsity: ...@@ -340,6 +354,13 @@ class SignalSparsity:
(Tensor): (Tensor):
Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values. Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
""" """
if not self._dst_enabled:
# Special case, DST is simply None, which represents an all-zero tensor.
return None
if sst is None:
sst = torch.zeros_like(dense, dtype=torch.complex64)
if not (dense.shape == sst.shape): if not (dense.shape == sst.shape):
raise ValueError("dense and sst have different shapes!") raise ValueError("dense and sst have different shapes!")
...@@ -349,7 +370,7 @@ class SignalSparsity: ...@@ -349,7 +370,7 @@ class SignalSparsity:
del dense del dense
return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim) return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim)
def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor: def sst_dst_to_dense(self, sst: Optional[Tensor], dst: Optional[Tensor] = None) -> Tensor:
"""From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns """From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns
the inverse transform of the SST tensor. the inverse transform of the SST tensor.
...@@ -363,12 +384,19 @@ class SignalSparsity: ...@@ -363,12 +384,19 @@ class SignalSparsity:
(Tensor): (Tensor):
A dense tensor in real number domain from the SST. A dense tensor in real number domain from the SST.
""" """
assert not (sst is None and dst is None), "both-None-case is not useful"
if sst is None:
# Simply the delta is the reconstruction.
return dst
# Now, ifft and then add the delta.
dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim)) dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim))
if dst is not None: if dst is not None:
dense_rt += dst dense_rt += dst
return dense_rt return dense_rt
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""From dense tensor to lossy reconstruction of dense tensor with the help of SST and DST """From dense tensor to lossy reconstruction of dense tensor with the help of SST and DST
tensor calculation. If requested sparsity is zero (or top_100_percent) then simply returns tensor calculation. If requested sparsity is zero (or top_100_percent) then simply returns
the input dense tensor as the reconstruction. the input dense tensor as the reconstruction.
...@@ -393,6 +421,8 @@ class SignalSparsity: ...@@ -393,6 +421,8 @@ class SignalSparsity:
# of the same size as dense. # of the same size as dense.
return dense, None, dense return dense, None, dense
else: else:
# depending on whether self._sst_enabled and self._dst_enabled, None SST/DST tensors can be returned
# below as well.
sst = self.dense_to_sst(dense) sst = self.dense_to_sst(dense)
dst = self.dense_sst_to_dst(dense, sst) dst = self.dense_sst_to_dst(dense, sst)
return self.sst_dst_to_dense(sst, dst), sst, dst return self.sst_dst_to_dense(sst, dst), sst, dst
...@@ -57,18 +57,15 @@ def get_valid_conf_arg_list(): ...@@ -57,18 +57,15 @@ def get_valid_conf_arg_list():
return dict(zip(arg_key_list, vals_list)) return dict(zip(arg_key_list, vals_list))
# Validate value error is raised when, either: # Validate value error is raised when, either:
# 1. One and only one of sst (or dst) percent and element is not provided a value (not None). # 1. both sst (or dst) percent and element is not provided a value (not None).
# 2. Both of sst (or dst) percent and element is set to None. # 2. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
# 3. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
element = 10 element = 10
percent = 50 percent = 50
dim = 0 dim = 0
args_list = [ args_list = [
[element, percent, dim, element, None, dim], # case 1. [element, percent, dim, element, None, dim], # case 1.
[element, None, dim, element, percent, dim], [element, None, dim, element, percent, dim],
[None, None, dim, element, None, dim], # case 2. [0, None, dim, None, None, dim], # case 2.
[element, None, dim, None, None, dim],
[0, None, dim, None, None, dim], # case 3.
[None, 0, dim, None, None, dim], [None, 0, dim, None, None, dim],
[element, None, dim, 0, None, dim], [element, None, dim, 0, None, dim],
[element, None, dim, None, 0, dim], [element, None, dim, None, 0, dim],
...@@ -399,3 +396,34 @@ def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent, device): ...@@ -399,3 +396,34 @@ def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent, device):
objects_are_equal(lossy_dense.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL) objects_are_equal(lossy_dense.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(sst, None, raise_exception=True, rtol=RTOL, atol=ATOL) objects_are_equal(sst, None, raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(dst.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL) objects_are_equal(dst.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
def test_sst_disabled():
"""Tests the case where SST is disabled."""
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000])
result = torch.tensor([0.0, 0.0, 0.7000, 0.8000])
sparser = SignalSparsity(dst_top_k_element=2, dst_top_k_dim=0)
rt, sst, dst = sparser.lossy_compress(dense)
objects_are_equal(rt, result, raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(dst, result, raise_exception=True, rtol=RTOL, atol=ATOL)
assert sst is None
def test_dst_disabled():
"""Tests the case where DST is disabled."""
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])
result_rt = torch.tensor([0.6000, 0.7618, 0.7000, 0.6382, 0.8000])
result_sst = torch.tensor(
[
3.50000000000000000000 + 0.00000000000000000000j,
0.00000000000000000000 + 0.00000000000000000000j,
-0.25000002980232238770 + 0.08122986555099487305j,
-0.25000002980232238770 - 0.08122986555099487305j,
0.00000000000000000000 + 0.00000000000000000000j,
]
)
sparser = SignalSparsity(sst_top_k_element=3, sst_top_k_dim=0)
rt, sst, dst = sparser.lossy_compress(dense)
objects_are_equal(rt, result_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(sst, result_sst, raise_exception=True, rtol=RTOL, atol=ATOL)
assert dst is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment