Unverified Commit 4c830de1 authored by Min Xu, committed by GitHub

[fix] bugs in signal sparsity class and improving tests (#1058)



* update examples and comment

* fixed issue with fft/ifft only doing the last dim

* fixed an int/round bug; fixed tests

* add cuda tests

* add atol and rtol

* skip cuda test correctly
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent f81a60be
@@ -477,7 +477,14 @@ class GPT2(Base):
return self.clf_head(h), logits
def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
def objects_are_equal(
a: Any,
b: Any,
raise_exception: bool = False,
dict_key: Optional[str] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> bool:
"""
Test that two objects are equal. Tensors are compared to ensure matching
size, dtype, device and values.
@@ -515,9 +522,9 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
return False
# assert_close.
if torch_version() < (1, 12, 0):
torch.testing.assert_allclose(a, b)
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
else:
torch.testing.assert_close(a, b)
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
return True
except (AssertionError, RuntimeError) as e:
if raise_exception:
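A minimal usage sketch of the extended helper (assuming the `fair_dev.testing.testing` import that the tests below use):

```python
import torch

from fair_dev.testing.testing import objects_are_equal

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7  # tiny numerical noise

# With rtol/atol left as None, torch.testing's own dtype-based defaults apply.
assert objects_are_equal(a, b)

# Explicit tolerances are forwarded to assert_close/assert_allclose.
assert objects_are_equal(a, b, raise_exception=True, rtol=1e-5, atol=1e-6)
```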
@@ -17,7 +17,7 @@ def _get_k_for_topk(topk_percent: Optional[float], top_k_element: Optional[int],
simply returns the value for k. Also, ensures k is never 0 to avoid all-zero tensors.
"""
if top_k_element is None:
top_k_element = int(top_k_total_size * topk_percent / 100.0)
top_k_element = round(top_k_total_size * topk_percent / 100.0)
elif top_k_element > top_k_total_size:
raise ValueError("top_k_element for sst or dst is larger than max number of elements along top_k_dim")
# ensure we never have 100% sparsity in tensor and always have 1 surviving element!
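The `int()` → `round()` change avoids always truncating a fractional k downward; for example:

```python
top_k_total_size = 12
topk_percent = 30.0

print(int(top_k_total_size * topk_percent / 100.0))    # 3 (old: truncates)
print(round(top_k_total_size * topk_percent / 100.0))  # 4 (new: rounds to nearest)
```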
@@ -86,12 +86,73 @@ def _is_sparsity_zero(
return k == top_k_total_size
def _dct_transform(dense: Tensor) -> Tensor:
def _fft_transform(dense: Tensor, dim: Optional[int]) -> Tensor:
"""Wrapper of torch.fft.fft with more flexibility on dimensions.
TODO (Min): figure out if we need to change other args like frequency length, n, or
the normalization flag.
For our use case, we use fft, not rfft, since we want large-magnitude components from
both positive and negative frequencies.
Args:
dense (Tensor):
Input dense tensor (no zeros).
dim (Optional[int]):
Which dimension to transform; if None, the tensor is flattened,
transformed as a 1-D signal, and reshaped back.
Returns:
(Tensor):
Transformed tensor with complex FFT components.
"""
orig_shape = None
if dim is None:
orig_shape = dense.shape
dense = dense.reshape(-1)
dim = -1
ret = torch.fft.fft(dense, dim=dim)
if orig_shape is not None:
ret = ret.reshape(orig_shape)
return ret
def _ifft_transform(sst: Tensor, dim: Optional[int]) -> Tensor:
"""Wrapper of torch.fft.ifft with more flexibility on dimensions.
Args:
sst (Tensor):
Input sst tensor (may have zeros) in frequency domain.
dim (Optional[int]):
Which dimension to transform; if None, the tensor is flattened,
transformed as a 1-D signal, and reshaped back.
Returns:
(Tensor):
A new, inverse-transformed tensor (complex dtype; callers typically take the real part).
"""
assert sst.is_complex()
orig_shape = None
if dim is None:
orig_shape = sst.shape
sst = sst.reshape(-1)
dim = -1
ret = torch.fft.ifft(sst, dim=dim)
if orig_shape is not None:
ret = ret.reshape(orig_shape)
return ret
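A round-trip sketch of the two wrappers' semantics, using plain torch.fft calls (the dim=None path flattens, transforms, and reshapes):

```python
import torch

x = torch.arange(12.0).reshape(4, 3)

freq_dim0 = torch.fft.fft(x, dim=0)                        # per-column FFT
freq_flat = torch.fft.fft(x.reshape(-1)).reshape(x.shape)  # dim=None path

# ifft inverts fft along the same dim; the imaginary part of the
# result is numerical noise for a real-valued input.
back = torch.fft.ifft(freq_dim0, dim=0).real
assert torch.allclose(back, x, atol=1e-6)
```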
def _dct_transform(dense: Tensor, dim: Optional[int]) -> Tensor:
"""Should perform a Discrete Cosine Transform on the input tensor.
Args:
dense (Tensor):
Input dense tensor (no zeros).
dim (Optional[int]):
Which dimension to transform.
Returns:
(Tensor):
transformed dense tensor DCT components
@@ -99,12 +160,14 @@ def _dct_transform(dense: Tensor) -> Tensor:
raise NotImplementedError("Support for DCT has not been implemented yet!")
def _inverse_dct_transform(sst: Tensor) -> Tensor:
def _idct_transform(sst: Tensor, dim: Optional[int]) -> Tensor:
"""Should perform an inverse Discrete Cosine Transform and return a new tensor.
Args:
sst (Tensor):
Input sst tensor (may have zeros) in frequency domain.
dim (Optional[int]):
Which dimension to transform.
Returns:
(Tensor):
A new, transformed dense tensor with real domain values.
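Both DCT entry points are still placeholders that raise NotImplementedError. If support were added, one possible route (an assumption, not part of this PR) is SciPy's scipy.fft.dct/idct; `_dct_transform_sketch` below is a hypothetical name:

```python
from typing import Optional

import torch
from scipy.fft import dct  # SciPy's DCT-II; an assumed implementation choice


def _dct_transform_sketch(dense: torch.Tensor, dim: Optional[int]) -> torch.Tensor:
    # Same flatten-when-None convention as the FFT wrappers above.
    if dim is None:
        out = dct(dense.reshape(-1).numpy())
        return torch.from_numpy(out).reshape(dense.shape)
    return torch.from_numpy(dct(dense.numpy(), axis=dim))
```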
@@ -128,6 +191,9 @@ class SignalSparsity:
`sst_top_k_element` or `sst_top_k_percent` and also requires a
value for one of `dst_top_k_element` or `dst_top_k_percent`.
This class only handles tensor inputs and outputs. Handling of
state_dict-style data is left to higher-level functions.
Args:
algo (Algo):
The algorithm used. Default: FFT
@@ -153,11 +219,11 @@
Example:
.. code-block:: python
2d_sparser = SignalSparsity()
sst, dst = 2d_sparser.get_sst_dst(linear.weight.data)
2d_sparser = SignalSparsity(sst_top_k_element=10, dst_top_k_element=1)
sst = 2d_sparser.dense_to_sst(linear.weight.data)
3d_sparser = SignalSparsity(algo=Algo.DCT, sst_top_k_dim=None, dst_top_k_dim=-1, sst_top_k_percent=10, dst_top_k_element=100)
conv.weight.data = 3d_sparser.get_sst_dst_weight(conv.weight.data)
3d_sparser = SignalSparsity(algo=Algo.FFT, sst_top_k_dim=None, dst_top_k_dim=-1, sst_top_k_percent=10, dst_top_k_element=100)
conv.weight.data, _, _ = 3d_sparser.lossy_compress(conv.weight.data)
"""
def __init__(
@@ -180,7 +246,9 @@ class SignalSparsity:
self._validate_conf()
# TODO (Min): Type checking for the following
self._transform, self._inverse_transform = (torch.fft.fft, torch.fft.ifft) if algo is Algo.FFT else (_dct_transform, _inverse_dct_transform) # type: ignore
self._transform, self._inverse_transform = (
(_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform)
)
def _validate_conf(self) -> None:
"""Validating if the config is valid.
@@ -248,13 +316,13 @@ class SignalSparsity:
"""
top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
dense_freq = self._transform(dense)
dense_freq = self._transform(dense, dim=self._sst_top_k_dim)
# NOTE: real_dense_freq can potentially be the magnitude of the complex frequency
# components, or the DCT-transformed components when using DCT (currently not implemented).
# TODO: In case of FFT, the imaginary part could perhaps be quantized, or pruning could
# be done on the smaller phases.
real_dense_freq = torch.real(dense_freq).abs()
real_dense_freq = dense_freq.real.abs()
return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)
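A standalone sketch of the selection step above: rank the FFT components by the absolute value of their real part and keep the top-k. The helper `_scatter_topk_to_sparse_tensor` (not shown in this diff) is assumed to do the zero-fill scatter written out here:

```python
import torch

dense = torch.arange(6.0).sin()
freq = torch.fft.fft(dense)      # complex frequency components
score = freq.real.abs()          # the ranking used by dense_to_sst

k = 2
_, idx = score.topk(k)
sst = torch.zeros_like(freq)
sst[idx] = freq[idx]             # sparse signal tensor (SST)
```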
def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
@@ -295,7 +363,7 @@ class SignalSparsity:
(Tensor):
A dense tensor in real number domain from the SST.
"""
dense_rt = torch.real(self._inverse_transform(sst))
dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim))
if dst is not None:
dense_rt += dst
return dense_rt
@@ -328,9 +396,3 @@ class SignalSparsity:
sst = self.dense_to_sst(dense)
dst = self.dense_sst_to_dst(dense, sst)
return self.sst_dst_to_dense(sst, dst), sst, dst
# We could separately have helper functions that work on a state_dict instead of a tensor.
# One option is to extend the above class to handle a state_dict as well as a tensor,
# but we may want to filter on the keys in the state_dict, so maybe that option isn't
# the best. We need to have further discussions on this.
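For reference, a minimal end-to-end sketch of the public API above (the tensor size and top-k values are arbitrary):

```python
import torch

from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

# Compress a weight tensor along dim=1, keeping 4 frequency components
# and 4 residual elements per row, then measure the reconstruction error.
weight = torch.randn(16, 16)
sparser = SignalSparsity(
    sst_top_k_element=4, sst_top_k_dim=1,
    dst_top_k_element=4, dst_top_k_dim=1,
)
lossy_dense, sst, dst = sparser.lossy_compress(weight)
print((lossy_dense - weight).abs().max())  # lossy reconstruction error
```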
@@ -30,6 +30,7 @@ from . import nn as nn
from . import testing as testing
from . import utils as utils
from . import jit as jit
from . import fft as fft
#MODIFIED BY TORCHGPIPE
from . import backends
@@ -117,6 +118,7 @@ class Tensor:
grad: Optional[Tensor] = ...
data: Tensor = ...
names: List[str] = ...
real: Tensor = ...
#MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
_has_been_cloned: Optional[bool] = ...
@@ -395,7 +397,6 @@ class Tensor:
def expm1(self) -> Tensor: ...
def expm1_(self) -> Tensor: ...
def exponential_(self, lambd: _float=1, *, generator: Generator=None) -> Tensor: ...
def fft(self, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
@overload
def fill_(self, value: Number) -> Tensor: ...
@overload
@@ -453,7 +454,6 @@ class Tensor:
def half(self) -> Tensor: ...
def hardshrink(self, lambd: Number=0.5) -> Tensor: ...
def histc(self, bins: _int=100, min: Number=0, max: Number=0) -> Tensor: ...
def ifft(self, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
def imag(self) -> Tensor: ...
@overload
def index_add(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
@@ -494,7 +494,6 @@ class Tensor:
def int(self) -> Tensor: ...
def int_repr(self) -> Tensor: ...
def inverse(self) -> Tensor: ...
def irfft(self, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True, signal_sizes: _size=()) -> Tensor: ...
def is_coalesced(self) -> _bool: ...
def is_complex(self) -> _bool: ...
def is_contiguous(self) -> _bool: ...
@@ -683,7 +682,6 @@ class Tensor:
def random_(self, to: _int, *, generator: Generator=None) -> Tensor: ...
@overload
def random_(self, *, generator: Generator=None) -> Tensor: ...
def real(self) -> Tensor: ...
def reciprocal(self) -> Tensor: ...
def reciprocal_(self) -> Tensor: ...
def refine_names(self, names: List[Union[str, None]]) -> Tensor: ...
@@ -720,7 +718,6 @@ class Tensor:
@overload
def resize_(self, *size: _int, memory_format: Optional[memory_format]=None) -> Tensor: ...
def resize_as_(self, the_template: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
def rfft(self, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True) -> Tensor: ...
def roll(self, shifts: Union[_int, _size], dims: Union[_int, _size]=()) -> Tensor: ...
def rot90(self, k: _int=1, dims: _size=(0,1)) -> Tensor: ...
def round(self) -> Tensor: ...
@@ -1317,7 +1314,6 @@ def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ...
def feature_alpha_dropout_(self: Tensor, p: _float, train: _bool) -> Tensor: ...
def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ...
def feature_dropout_(self: Tensor, p: _float, train: _bool) -> Tensor: ...
def fft(self: Tensor, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
@overload
def fill_(self: Tensor, value: Number) -> Tensor: ...
@overload
@@ -1394,7 +1390,6 @@ def hann_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, lay
def hardshrink(self: Tensor, lambd: Number=0.5) -> Tensor: ...
def histc(self: Tensor, bins: _int=100, min: Number=0, max: Number=0, *, out: Optional[Tensor]=None) -> Tensor: ...
def hspmm(mat1: Tensor, mat2: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def ifft(self: Tensor, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
def imag(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
@overload
def index_add(self: Tensor, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
@@ -1421,7 +1416,6 @@ def index_select(self: Tensor, dim: Union[str, None], index: Tensor, *, out: Opt
def instance_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], use_input_stats: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ...
def int_repr(self: Tensor) -> Tensor: ...
def inverse(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def irfft(self: Tensor, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True, signal_sizes: _size=()) -> Tensor: ...
def is_complex(self: Tensor) -> _bool: ...
def is_distributed(self: Tensor) -> _bool: ...
def is_floating_point(self: Tensor) -> _bool: ...
@@ -1705,7 +1699,6 @@ def result_type(tensor: Tensor, other: Number) -> _dtype: ...
def result_type(scalar: Number, tensor: Tensor) -> _dtype: ...
@overload
def result_type(scalar1: Number, scalar2: Number) -> _dtype: ...
def rfft(self: Tensor, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True) -> Tensor: ...
@overload
def rnn_relu(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ...
@overload
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional
from torch import Tensor
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# are necessary
#_int = builtins.int
#_float = builtins.float
#_bool = builtins.bool
#_size = Union[Size, List[int], Tuple[int, ...]]
def fft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
def ifft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
def rfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
def irfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
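These stubs mirror the signatures of torch.fft's public functions, e.g.:

```python
import torch

x = torch.randn(8)
freq = torch.fft.fft(x, n=8, dim=-1, norm="ortho")
back = torch.fft.ifft(freq, n=8, dim=-1, norm="ortho")
assert torch.allclose(back.real, x, atol=1e-6)
```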
@@ -9,118 +9,35 @@ import torch
from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.signal_sparsity import SignalSparsity
# Our own tolerances
ATOL = 1e-6
RTOL = 1e-5
def get_test_params():
"""Helper function to create and return a list of tuples of the form:
(dense, expected_sst, expected_dst, expected_reconstructed_tensor (RT), dim, percent, top_k_element)
to be used as parameters for tests.
"""
# input in_tensors
tensor_4x3_None = torch.arange(12).reshape(4, 3).float()
tensor_4x3_0 = torch.arange(50, 62).reshape(4, 3) / 100
tensor_3x3_1 = torch.linspace(-5, 5, 9).reshape(3, 3)
tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2).float()
# with dim=None, top-2
expd_sst_4x3_None = torch.tensor(
[
[0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[21.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[30.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
],
dtype=torch.complex64,
)
# with dim=None, top-2
expd_dst_4x3_None = torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32
)
# expected_reconstructed_tensor with dim=None and top-2 for both sst and dst
expd_rt_4x3_None = torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [7.0, 7.0, 7.0], [10.0, 10.0, 10.0]], dtype=torch.float32
)
# with dim=0, top-2
expd_sst_4x3_0 = torch.tensor(
[
[0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
[0.0000000000 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
[1.7100000381 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
[1.7999999523 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
],
dtype=torch.complex64,
)
# with dim=0, top-2
expd_dst_4x3_0 = torch.tensor(
[[0.5000, 0.5100, 0.5200], [0.5400, 0.5400, 0.5400], [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]]
)
# expected_reconstructed_tensor with dim=0 and top-2 for both sst and dst
expd_rt_4x3_0 = torch.tensor(
[[0.5000, 0.5100, 0.5200], [0.5300, 0.5400, 0.5500], [0.5700, 0.5700, 0.5700], [0.5900, 0.6000, 0.6100]]
)
# with dim=1, top-2
expd_sst_3x3_1 = torch.tensor(
[
[-11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
[0.0000000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, -1.8750000000 - 1.0825316906j],
[11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
],
dtype=torch.complex64,
)
# with dim=1, top-2
expd_dst_3x3_1 = torch.tensor(
[
[-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
[0.0000000000e00, -4.8244856998e-08, 0.0000000000e00],
[-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
]
)
# expected_reconstructed_tensor with dim=1 and top-2 for both sst and dst
expd_rt_3x3_1 = torch.tensor([[-5.0000, -3.7500, -2.5000], [-1.2500, 0.0000, 1.2500], [2.5000, 3.7500, 5.0000]])
# with dim=1, top-1
expd_sst_2x2x3_1 = torch.tensor(
[
[[0.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, 0.0 + 0.0j]],
[[0.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, 0.0 + 0.0j]],
[[0.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, 0.0 + 0.0j]],
],
dtype=torch.complex64,
)
# enable this for debugging.
# torch.set_printoptions(precision=20)
# with dim=1, top-1
expd_dst_2x2x3_1 = torch.tensor(
[
[[0.5000, 0.5000], [0.0000, 0.0000]],
[[4.5000, 4.5000], [0.0000, 0.0000]],
[[8.5000, 8.5000], [0.0000, 0.0000]],
],
dtype=torch.float32,
)
# expected_reconstructed_tensor with dim=1 and top-1 for both sst and dst
expd_rt_2x2x3_1 = torch.tensor(
[
[[0.0000, 1.0000], [2.5000, 2.5000]],
[[4.0000, 5.0000], [6.5000, 6.5000]],
[[8.0000, 9.0000], [10.5000, 10.5000]],
],
dtype=torch.float32,
)
return [
(tensor_4x3_None, expd_sst_4x3_None, expd_dst_4x3_None, expd_rt_4x3_None, None, 20, 2),
(tensor_4x3_0, expd_sst_4x3_0, expd_dst_4x3_0, expd_rt_4x3_0, 0, 50, 2),
(tensor_3x3_1, expd_sst_3x3_1, expd_dst_3x3_1, expd_rt_3x3_1, 1, 70, 2),
(tensor_2x2x3, expd_sst_2x2x3_1, expd_dst_2x2x3_1, expd_rt_2x2x3_1, 1, 50, 1),
]
@pytest.mark.parametrize(
"dense, k, dim",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 40, None), # top-40, dim=None
(torch.linspace(0.1, 0.6, 30).reshape(5, 6), 5, 0), # top-5, dim=0
(torch.linspace(-0.1, 0.6, 35).reshape(7, 5), 5, 1), # top-5, dim=1
(torch.arange(60).float().reshape(10, 6), 60, None), # top-60, dim=None
(torch.arange(60).float().reshape(10, 6), 10, 0), # top-10, dim=0
(torch.arange(60).float().reshape(10, 6), 6, 1), # top-6, dim=1
(torch.arange(60).float().reshape(2, 5, 6), 5, 1), # top-5, dim=1
],
)
def test_sst_dst_to_perfect_dense_reconstruction(dense, k, dim):
"""Tests whether perfect reconstruction of input dense tensor is generated when top-k matches the numel
across some dimension dim for both SST and DST.
"""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
sst = sparser.dense_to_sst(dense)
dst = sparser.dense_sst_to_dst(dense, sst)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense, dense_recons, raise_exception=True, rtol=RTOL, atol=ATOL)
def get_valid_conf_arg_list():
@@ -182,15 +99,240 @@ def test_dense_to_sst_perfect_recons(tensor, dim):
when top_k_percent is set at 100.
"""
sparser_2d = SignalSparsity(sst_top_k_percent=100, sst_top_k_dim=dim, dst_top_k_percent=100)
assert all((sparser_2d.dense_to_sst(tensor) == torch.fft.fft(tensor)).flatten())
if dim is None:
fft_tensor = torch.fft.fft(tensor.flatten()).reshape(tensor.shape)
else:
fft_tensor = torch.fft.fft(tensor, dim=dim)
assert all((sparser_2d.dense_to_sst(tensor) == fft_tensor).flatten())
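The dim=None branch matters because torch.fft.fft defaults to the last dimension; a quick check of the bug this PR fixes:

```python
import torch

t = torch.arange(12.0).reshape(4, 3)
last_dim_only = torch.fft.fft(t)  # equivalent to torch.fft.fft(t, dim=-1)
whole_tensor = torch.fft.fft(t.reshape(-1)).reshape(t.shape)
assert not torch.equal(last_dim_only, whole_tensor)
```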
#
# Below are fixed input/output testing.
#
def get_test_params():
"""Helper function to create and return a list of tuples of the form:
(dense, expected_sst, expected_dst, expected_reconstructed_tensor (RT), dim, percent, top_k_element)
to be used as parameters for tests.
"""
# Input tensor 0.
# We use `sin()` below to make sure the top-2 values are not sensitive to index
# ordering. With just `arange()`, we get a straight line and the resulting
# FFT has many identical second-largest values. That makes top-2 potentially
# non-deterministic and implementation-dependent.
tensor_4x3_None = torch.arange(12).sin().reshape(4, 3).float()
# Values are: [[ 0.00000000000000000000, 0.84147095680236816406, 0.90929740667343139648],
# [ 0.14112000167369842529, -0.75680249929428100586, -0.95892429351806640625],
# [-0.27941548824310302734, 0.65698659420013427734, 0.98935824632644653320],
# [ 0.41211849451065063477, -0.54402112960815429688, -0.99999022483825683594]]
# SST: with dim=None, top-2
expd_sst_4x3_None = torch.tensor(
[
[0.0000 + 0.0000j, 0.0000 + 0.0000j, -1.3618 - 5.7650j],
[0.0000 + 0.0000j, 0.0000 + 0.0000j, 0.0000 + 0.0000j],
[0.0000 + 0.0000j, 0.0000 + 0.0000j, 0.0000 + 0.0000j],
[0.0000 + 0.0000j, -1.3618 + 5.7650j, 0.0000 + 0.0000j],
],
dtype=torch.complex64,
)
# DST: with dim=None, top-2
expd_dst_4x3_None = torch.tensor(
[
[0.22696666419506072998, 0.00000000000000000000, 0.00000000000000000000],
[0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000],
[0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000],
[0.18515183031558990479, 0.00000000000000000000, 0.00000000000000000000],
]
)
# RT: expected_reconstructed_tensor with dim=None and top-2 for both sst and dst
expd_rt_4x3_None = torch.tensor(
[
[0.00000000000000000000, 0.71862268447875976562, 0.94558942317962646484],
[0.22696666419506072998, -0.71862268447875976562, -0.94558942317962646484],
[-0.22696666419506072998, 0.71862268447875976562, 0.94558942317962646484],
[0.41211849451065063477, -0.71862268447875976562, -0.94558942317962646484],
]
)
# Input tensor 1.
tensor_4x3_0 = torch.arange(50, 62).sin().reshape(4, 3) / 100
# Values are: [[-0.00262374849990010262, 0.00670229177922010422, 0.00986627582460641861],
# [ 0.00395925156772136688, -0.00558789074420928955, -0.00999755132943391800],
# [-0.00521551026031374931, 0.00436164764687418938, 0.00992872659116983414],
# [ 0.00636737979948520660, -0.00304810609668493271, -0.00966117810457944870]]
# SST: with dim=0, top-1 (use top-1 because top-2 and top-3 would include some identical values)
expd_sst_4x3_0 = torch.tensor(
[
[0.0000 + 0.0j, 0.0000 + 0.0j, 0.0000 + 0.0j],
[0.0000 + 0.0j, 0.0000 + 0.0j, 0.0000 + 0.0j],
[-1.81658901274204254150e-02 + 0.0j, 1.96999348700046539307e-02 + 0.0j, 3.94537299871444702148e-02 + 0.0j],
[0.0000 + 0.0j, 0.0000 + 0.0j, 0.0000 + 0.0j],
],
dtype=torch.complex64,
)
# DST: with dim=0, top-2
expd_dst_4x3_0 = torch.tensor(
[
[0.00191772403195500374, 0.00000000000000000000, 0.00000000000000000000],
[0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000],
[0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000],
[0.00000000000000000000, 0.00187687762081623077, 0.00020225439220666885],
]
)
# RT: expected_reconstructed_tensor with dim=0 and top-2 for both sst and dst
expd_rt_4x3_0 = torch.tensor(
[
[-0.00262374849990010262, 0.00492498371750116348, 0.00986343249678611755],
[0.00454147253185510635, -0.00492498371750116348, -0.00986343249678611755],
[-0.00454147253185510635, 0.00492498371750116348, 0.00986343249678611755],
[0.00454147253185510635, -0.00304810609668493271, -0.00966117810457944870],
]
)
# Input tensor 2.
tensor_3x5_1 = torch.Tensor([0, 2, 3, 1, 6, 5, 7, 4, 8, 11, 9, 10, 0, 2, 5]).reshape(3, 5)
# SST: with dim=1, top-3, because the FFT always has symmetric output after the top-1
expd_sst_3x5_1 = torch.tensor(
[
[
12.00000000000000000000 + 0.00000000000000000000j,
0,
-5.23606777191162109375 + 4.25325393676757812500j,
-5.23606777191162109375 - 4.25325393676757812500j,
0,
],
[
35.00000000000000000000 + 0.00000000000000000000j,
0,
-5.85410213470458984375 - 1.45308518409729003906j,
-5.85410213470458984375 + 1.45308518409729003906j,
0,
],
[
26.00000000000000000000 + 0.00000000000000000000j,
12.01722049713134765625 - 3.57971239089965820312j,
0,
0,
12.01722049713134765625 + 3.57971239089965820312j,
],
],
dtype=torch.complex64,
)
# DST: with dim=1, top-2
expd_dst_3x5_1 = torch.tensor(
[
[
0.00000000000000000000,
-1.09442710876464843750,
0.00000000000000000000,
0.86524754762649536133,
0.90557289123535156250,
],
[
0.00000000000000000000,
-2.23606777191162109375,
-1.72360706329345703125,
0.00000000000000000000,
2.44721317291259765625,
],
[
0.00000000000000000000,
1.95278644561767578125,
-2.15278673171997070312,
1.53049504756927490234,
0.00000000000000000000,
],
],
)
# RT: expected_reconstructed_tensor with dim=1 and top-2 for both sst and dst
expd_rt_3x5_1 = torch.tensor(
[
[
0.30557289719581604004,
2.00000000000000000000,
3.37082076072692871094,
1.00000000000000000000,
6.00000000000000000000,
],
[
4.65835905075073242188,
7.00000000000000000000,
4.00000000000000000000,
6.82917928695678710938,
11.00000000000000000000,
],
[
10.00688838958740234375,
10.00000000000000000000,
0.00000000000000000000,
2.00000000000000000000,
5.32360696792602539062,
],
]
)
# Input tensor 3.
tensor_3x2x2 = torch.arange(12).cos().reshape(3, 2, 2).float()
# Values are: [[[ 1.00000000000000000000, 0.54030233621597290039],
# [-0.41614684462547302246, -0.98999249935150146484]],
# [[-0.65364360809326171875, 0.28366219997406005859],
# [ 0.96017026901245117188, 0.75390225648880004883]],
# [[-0.14550003409385681152, -0.91113024950027465820],
# [-0.83907151222229003906, 0.00442569795995950699]]]
# SST: with dim=1, top-1
expd_sst_3x2x2_1 = torch.tensor(
[
[[0, 0], [1.41614687442779541016 + 0.0j, 1.53029489517211914062 + 0.0j]],
[[0, 1.03756451606750488281 + 0.0j], [-1.61381387710571289062 + 0.0j, 0]],
[[-0.98457157611846923828 + 0.0j, 0], [0, -0.91555595397949218750 + 0.0j]],
],
dtype=torch.complex64,
)
# DST: with dim=1, top-1
expd_dst_3x2x2_1 = torch.tensor(
[
[[0.00000000000000000000, -0.22484511137008666992], [0.29192659258842468262, 0.00000000000000000000]],
[[0.15326333045959472656, -0.23512005805969238281], [0.00000000000000000000, 0.00000000000000000000]],
[[0.34678575396537780762, -0.45335227251052856445], [0.00000000000000000000, 0.00000000000000000000]],
]
)
# RT: expected_reconstructed_tensor with dim=1 and top-1 for both sst and dst
expd_rt_3x2x2_1 = torch.tensor(
[
[[0.70807343721389770508, 0.54030233621597290039], [-0.41614684462547302246, -0.76514744758605957031]],
[[-0.65364360809326171875, 0.28366219997406005859], [0.80690693855285644531, 0.51878225803375244141]],
[[-0.14550003409385681152, -0.91113024950027465820], [-0.49228578805923461914, 0.45777797698974609375]],
]
)
return [
# input, expected sst, dst, rt, sst_dim, percent, top_k.
(tensor_4x3_None, expd_sst_4x3_None, expd_dst_4x3_None, expd_rt_4x3_None, None, 2 / 12 * 100, 2),
(tensor_4x3_0, expd_sst_4x3_0, expd_dst_4x3_0, expd_rt_4x3_0, 0, 1 / 3 * 100, 1),
(tensor_3x5_1, expd_sst_3x5_1, expd_dst_3x5_1, expd_rt_3x5_1, 1, 3 / 5 * 100, 3),
(tensor_3x2x2, expd_sst_3x2x2_1, expd_dst_3x2x2_1, expd_rt_3x2x2_1, 1, 1 / 2 * 100, 1),
]
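The frozen expectations above can be regenerated numerically; a sketch for the first case (dim=None, top-2 SST), assuming the |real|-based ranking used by `dense_to_sst`:

```python
import torch

torch.set_printoptions(precision=20)  # match the precision printed above

dense = torch.arange(12).sin().reshape(4, 3).float()
freq = torch.fft.fft(dense.reshape(-1))
_, idx = freq.real.abs().topk(2)
sst = torch.zeros_like(freq)
sst[idx] = freq[idx]
print(sst.reshape(4, 3))  # compare against expd_sst_4x3_None
```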
@pytest.mark.parametrize("tensor, expd_sst, unused1, unused2, dim, unused3, k", get_test_params())
def test_dense_to_sst_fixed(tensor, expd_sst, unused1, unused2, dim, unused3, k):
def test_dense_to_sst(tensor, expd_sst, unused1, unused2, dim, unused3, k):
"""Tests for fixed input dense tensor and fixed expected output SST tensor."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sparser_2d = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sst = sparser_2d.dense_to_sst(tensor)
objects_are_equal(sst, expd_sst, raise_exception=True)
objects_are_equal(sst, expd_sst, raise_exception=True, rtol=RTOL, atol=ATOL)
@pytest.mark.parametrize("tensor, unused1, unused2, unused3, dim, percent, k", get_test_params())
@@ -203,56 +345,37 @@ def test_percent_element(tensor, unused1, unused2, unused3, dim, percent, k):
sst_top_k_percent=percent, sst_top_k_element=None, sst_top_k_dim=dim, dst_top_k_percent=100
)
sst_percent = sparser_2d.dense_to_sst(tensor)
objects_are_equal(sst_element, sst_percent, raise_exception=True)
objects_are_equal(sst_element, sst_percent, raise_exception=True, rtol=RTOL, atol=ATOL)
@pytest.mark.parametrize("tensor, sst, expd_dst, unused1, dim, unused2, k", get_test_params())
def test_dense_sst_to_dst(tensor, sst, expd_dst, unused1, dim, unused2, k):
"""Tests fixed expected output DST tensor with fixed input dense and SST tensors."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, dst_top_k_element=k, dst_top_k_dim=dim)
sparser_2d = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dst = sparser_2d.dense_sst_to_dst(tensor, sst)
objects_are_equal(dst, expd_dst, raise_exception=True)
objects_are_equal(dst, expd_dst, raise_exception=True, rtol=RTOL, atol=ATOL)
@pytest.mark.parametrize(
"dense, k, dim",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 40, None), # top-40, dim=None
(torch.linspace(0.1, 0.6, 30).reshape(5, 6), 5, 0), # top-5, dim=0
(torch.linspace(-0.1, 0.6, 35).reshape(7, 5), 5, 1), # top-5, dim=1
(torch.arange(60).float().reshape(10, 6), 60, None), # top-60, dim=None
(torch.arange(60).float().reshape(10, 6), 10, 0), # top-10, dim=0
(torch.arange(60).float().reshape(10, 6), 6, 1), # top-6, dim=1
(torch.arange(60).float().reshape(2, 5, 6), 5, 1), # top-5, dim=1
],
)
def test_sst_dst_to_perfect_dense_reconstruction(dense, k, dim):
"""Tests whether perfect reconstruction of input dense tensor is generated when top-k matches the numel
across some dimension dim for both SST and DST.
"""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
sst = sparser.dense_to_sst(dense)
dst = sparser.dense_sst_to_dst(dense, sst)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense, dense_recons, raise_exception=True)
@pytest.mark.parametrize("unused1, sst, dst, expd_rt, dim, unused2, k", get_test_params())
def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
@pytest.mark.parametrize("unused1, sst, dst, expd_rt, dim, unused2, unused3", get_test_params())
def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, unused3):
"""Tests the correct expected reconstruction from frozen sst and dst tensors."""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
sparser = SignalSparsity(sst_top_k_element=1, sst_top_k_dim=dim, dst_top_k_element=1, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)
objects_are_equal(dense_recons, expd_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
@pytest.mark.parametrize("tensor, expd_sst, expd_dst, expd_rt, dim, unused, k", get_test_params())
def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k):
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k, device):
"""Tests the lossy_compress method against expected sst, dst and reconstruced tensor."""
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("no GPU")
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, expd_rt, raise_exception=True)
objects_are_equal(sst, expd_sst, raise_exception=True)
objects_are_equal(dst, expd_dst, raise_exception=True)
lossy_dense, sst, dst = sparser.lossy_compress(tensor.to(device))
objects_are_equal(sst.to(device), expd_sst.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(dst.to(device), expd_dst.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(lossy_dense.to(device), expd_rt.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
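The in-body skip keeps the parametrize list simple; an equivalent, commonly used alternative (not part of this PR) is a skipif marker:

```python
import pytest
import torch

requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="no GPU")

@requires_cuda
def test_needs_gpu():
    assert torch.zeros(1, device="cuda").is_cuda
```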
@pytest.mark.parametrize(
@@ -263,12 +386,16 @@ def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k):
(torch.linspace(-10, 15, 36).reshape(6, 6), 1, 100),
],
)
def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent):
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent, device):
"""Tests whether lossy_compress method simply returns dense tensor when sparsity is 0."""
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("no GPU")
sparser = SignalSparsity(
sst_top_k_percent=top_k_percent, sst_top_k_dim=dim, dst_top_k_percent=top_k_percent, dst_top_k_dim=dim
)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, tensor, raise_exception=True)
objects_are_equal(sst, None, raise_exception=True)
objects_are_equal(dst, tensor, raise_exception=True)
lossy_dense, sst, dst = sparser.lossy_compress(tensor.to(device))
objects_are_equal(lossy_dense.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(sst, None, raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(dst.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)