Unverified Commit 4c830de1 authored by Min Xu, committed by GitHub

[fix] bugs in signal sparsity class and improving tests (#1058)



* update examples and comment

* fixed issue with fft/ifft only doing the last dim

* fixed an int/round bug; fixed tests

* add cuda tests

* add atol and rtol

* skip cuda test correctly
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent f81a60be
@@ -477,7 +477,14 @@ class GPT2(Base):
         return self.clf_head(h), logits


-def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
+def objects_are_equal(
+    a: Any,
+    b: Any,
+    raise_exception: bool = False,
+    dict_key: Optional[str] = None,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> bool:
     """
     Test that two objects are equal. Tensors are compared to ensure matching
     size, dtype, device and values.
@@ -515,9 +522,9 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                 return False
             # assert_close.
             if torch_version() < (1, 12, 0):
-                torch.testing.assert_allclose(a, b)
+                torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
             else:
-                torch.testing.assert_close(a, b)
+                torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
             return True
         except (AssertionError, RuntimeError) as e:
             if raise_exception:
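To illustrate what the new `rtol`/`atol` knobs buy callers of this helper, here is a minimal sketch. The tensors and tolerance values are illustrative, and the import path is an assumption (it is not shown in this diff):

```python
import torch

from fairscale.utils.testing import objects_are_equal  # assumed import path, not shown in this diff

a = torch.randn(4, 4)
b = a + 1e-4  # small numerical drift, e.g. from running the same op on another device

# With the default tolerances of assert_close, a 1e-4 difference is rejected...
strict = objects_are_equal(a, b)
# ...while a caller-supplied tolerance accepts it.
loose = objects_are_equal(a, b, rtol=0.0, atol=1e-3)
print(strict, loose)  # expected: False True
```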
@@ -17,7 +17,7 @@ def _get_k_for_topk(topk_percent: Optional[float], top_k_element: Optional[int],
     simply returns the value for k. Also, ensures k is never 0 to avoid all-zero tensors.
     """
     if top_k_element is None:
-        top_k_element = int(top_k_total_size * topk_percent / 100.0)
+        top_k_element = round(top_k_total_size * topk_percent / 100.0)
    elif top_k_element > top_k_total_size:
        raise ValueError("top_k_element for sst or dst is larger than max number of elements along top_k_dim")
    # ensure we never have 100% sparsity in tensor and always have 1 surviving element!
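The `int` → `round` change matters because `int()` truncates toward zero, so the computed k can silently under-shoot the requested percentage. A quick illustration of the difference (the sizes and percentages here are made up, not taken from the diff):

```python
# Truncating vs. rounding when converting a top-k percentage into an element count.
top_k_total_size = 7
topk_percent = 50.0

k_trunc = int(top_k_total_size * topk_percent / 100.0)    # int(3.5)   -> 3
k_round = round(top_k_total_size * topk_percent / 100.0)  # round(3.5) -> 4

print(k_trunc, k_round)  # 3 4 -- round() lands closer to the requested 50%
```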
@@ -86,12 +86,73 @@ def _is_sparsity_zero(
     return k == top_k_total_size


-def _dct_transform(dense: Tensor) -> Tensor:
+def _fft_transform(dense: Tensor, dim: int) -> Tensor:
+    """Wrapper of torch.fft.fft with more flexibility on dimensions.
+
+    TODO (Min): figure out if we need to change other args like frequency length, n, or
+                the normalization flag.
+
+    For our use case, we use fft not rfft since we want big magnitute components from
+    both positive and negative frequencies.
+
+    Args:
+        dense (Tensor):
+            Input dense tensor (no zeros).
+        dim (int):
+            Which dimension to transform.
+
+    Returns:
+        (Tensor, complex):
+            transformed dense tensor FFT components.
+    """
+    orig_shape = None
+    if dim is None:
+        orig_shape = dense.shape
+        dense = dense.reshape(-1)
+        dim = -1
+    ret = torch.fft.fft(dense, dim=dim)
+    if orig_shape is not None:
+        ret = ret.reshape(orig_shape)
+    return ret
+
+
+def _ifft_transform(sst: Tensor, dim: int) -> Tensor:
+    """Wrapper of torch.fft.ifft with more flexibility on dimensions.
+
+    Args:
+        sst (Tensor):
+            Input sst tensor (may have zeros) in frequency domain.
+        dim (int):
+            Which dimension to transform.
+
+    Returns:
+        (Tensor):
+            A new, transformed dense tensor with real domain values.
+    """
+    assert sst.is_complex()
+    orig_shape = None
+    if dim is None:
+        orig_shape = sst.shape
+        sst = sst.reshape(-1)
+        dim = -1
+    ret = torch.fft.ifft(sst, dim=dim)
+    if orig_shape is not None:
+        ret = ret.reshape(orig_shape)
+    return ret
+
+
+def _dct_transform(dense: Tensor, dim: int) -> Tensor:
     """Should take a tensor and perform a Discrete Cosine Transform on the tensor.

     Args:
         dense (Tensor):
             Input dense tensor (no zeros).
+        dim (int):
+            Which dimension to transform.

     Returns:
         (Tensor):
             transformed dense tensor DCT components
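The new wrappers exist so the transform can run along a caller-chosen dimension, or over the flattened tensor when `dim` is None. A short sketch of the underlying behaviour they build on, using plain `torch.fft` calls independent of the helpers above:

```python
import torch

x = torch.randn(3, 8)

# Per-dimension transform: each row is transformed independently.
freq = torch.fft.fft(x, dim=-1)           # complex tensor, same shape as x
back = torch.fft.ifft(freq, dim=-1).real  # imaginary parts are ~0 for real input
assert torch.allclose(back, x, atol=1e-6)

# The dim=None path in the wrappers: flatten, transform one long signal, restore shape.
flat_freq = torch.fft.fft(x.reshape(-1), dim=-1).reshape(x.shape)
assert flat_freq.shape == x.shape and flat_freq.is_complex()
```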
@@ -99,12 +160,14 @@ def _dct_transform(dense: Tensor) -> Tensor:
     raise NotImplementedError("Support for DCT has not been implemented yet!")


-def _inverse_dct_transform(sst: Tensor) -> Tensor:
+def _idct_transform(sst: Tensor, dim: int) -> Tensor:
     """Should take a tensor and perform an inverse Discrete Cosine Transform and return a new tensor.

     Args:
         sst (Tensor):
             Input sst tensor (may have zeros) in frequency domain.
+        dim (int):
+            Which dimension to transform.

     Returns:
         (Tensor):
             A new, transformed dense tensor with real domain values.
@@ -128,6 +191,9 @@ class SignalSparsity:
     `sst_top_k_element` or `sst_top_k_percent` and also requires a
     value for one of `dst_top_k_element` or `dst_top_k_percent`.

+    This class only handles tensor inputs and outputs. We leave
+    state_dict type of data handling to upper layer functions.
+
     Args:
         algo (Algo):
             The algorithm used. Default: FFT
@@ -153,11 +219,11 @@ class SignalSparsity:
     Example:
         .. code-block:: python

-            2d_sparser = SignalSparsity()
-            sst, dst = 2d_sparser.get_sst_dst(linear.weight.data)
+            2d_sparser = SignalSparsity(sst_top_k_element=10, dst_top_k_element=1)
+            sst = 2d_sparser.dense_to_sst(linear.weight.data)

-            3d_sparser = SingalSparsity(algo=Algo.DCT, sst_top_k_dim=None, dst_top_k_dim=-1, sst_top_k_percent=10, dst_top_k_element=100)
-            conv.weight.data = 3d_sparser.get_sst_dst_weight(conv.weight.data)
+            3d_sparser = SingalSparsity(algo=Algo.FFT, sst_top_k_dim=None, dst_top_k_dim=-1, sst_top_k_percent=10, dst_top_k_element=100)
+            conv.weight.data, _, _ = 3d_sparser.lossy_compress(conv.weight.data)
     """

     def __init__(
@@ -180,7 +246,9 @@ class SignalSparsity:
         self._validate_conf()

         # TODO (Min): Type checking for the following
-        self._transform, self._inverse_transform = (torch.fft.fft, torch.fft.ifft) if algo is Algo.FFT else (_dct_transform, _inverse_dct_transform)  # type: ignore
+        self._transform, self._inverse_transform = (
+            (_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform)
+        )

     def _validate_conf(self) -> None:
         """Validating if the config is valid.
@@ -248,13 +316,13 @@ class SignalSparsity:
         """
        top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
        k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
-        dense_freq = self._transform(dense)
+        dense_freq = self._transform(dense, dim=self._sst_top_k_dim)

         # NOTE: real_dense_freq can potentially be magnitude of complex frequency components
         #       or DCT transformed components when using DCT (currently not implemented).
         # TODO: In case of the FFT, the imaginary part can perhaps be quantized or pruning can be
         #       done on the smaller phases.
-        real_dense_freq = torch.real(dense_freq).abs()
+        real_dense_freq = dense_freq.real.abs()
         return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)

     def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
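The two one-line changes above thread the chosen dimension through the transform and read the real part via the `.real` attribute instead of `torch.real()`. For intuition, here is a rough sketch of the SST selection that `dense_to_sst` performs; `_scatter_topk_to_sparse_tensor` itself is not shown in this diff, so this only mirrors the idea rather than reproducing the library implementation:

```python
import torch

# Keep the k frequency components with the largest |real part| along the last dim,
# zero out everything else (a conceptual stand-in for _scatter_topk_to_sparse_tensor).
x = torch.randn(4, 16)
k = 4

freq = torch.fft.fft(x, dim=-1)
score = freq.real.abs()
idx = score.topk(k, dim=-1).indices
sst = torch.zeros_like(freq)                    # complex, all zeros
sst.scatter_(-1, idx, freq.gather(-1, idx))     # surviving components copied back in
```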
@@ -295,7 +363,7 @@ class SignalSparsity:
             (Tensor):
                 A dense tensor in real number domain from the SST.
         """
-        dense_rt = torch.real(self._inverse_transform(sst))
+        dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim))
         if dst is not None:
             dense_rt += dst
         return dense_rt
@@ -328,9 +396,3 @@ class SignalSparsity:
         sst = self.dense_to_sst(dense)
         dst = self.dense_sst_to_dst(dense, sst)
         return self.sst_dst_to_dense(sst, dst), sst, dst
-
-
-# We could separate have helper functions that work on state_dict instead of a tensor.
-# One option is to extend the above class to handle state_dict as well as tensor
-# but we may want to filter on the keys in the state_dict, so maybe that option isn't
-# the best. We need to have further discussions on this.
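Putting this file's changes together, the class is driven roughly as follows. This is a hedged sketch based only on the methods visible in these hunks; the import path and constructor argument values are assumptions, not part of the diff:

```python
import torch

# Import path assumed from the module being patched above; not shown in this diff.
from fairscale.experimental.wgit.signal_sparsity import Algo, SignalSparsity

sparser = SignalSparsity(
    algo=Algo.FFT,
    sst_top_k_dim=-1, sst_top_k_element=4,  # illustrative values
    dst_top_k_dim=-1, dst_top_k_element=4,
)

dense = torch.randn(8, 16)

# Step by step: frequency-domain top-k (SST), residual top-k in the real domain (DST),
# then reconstruction of an approximation from the two sparse pieces.
sst = sparser.dense_to_sst(dense)            # complex, mostly zeros
dst = sparser.dense_sst_to_dst(dense, sst)   # real residual, mostly zeros
approx = sparser.sst_dst_to_dense(sst, dst)

# Or in one call, as the updated docstring example does.
approx2, sst2, dst2 = sparser.lossy_compress(dense)
```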
@@ -30,6 +30,7 @@ from . import nn as nn
 from . import testing as testing
 from . import utils as utils
 from . import jit as jit
+from . import fft as fft

 #MODIFIED BY TORCHGPIPE
 from . import backends
@@ -117,6 +118,7 @@ class Tensor:
     grad: Optional[Tensor] = ...
     data: Tensor = ...
     names: List[str] = ...
+    real: Tensor = ...

     #MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
     _has_been_cloned: Optional[bool] = ...
@@ -395,7 +397,6 @@ class Tensor:
     def expm1(self) -> Tensor: ...
     def expm1_(self) -> Tensor: ...
     def exponential_(self, lambd: _float=1, *, generator: Generator=None) -> Tensor: ...
-    def fft(self, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
     @overload
     def fill_(self, value: Number) -> Tensor: ...
     @overload
@@ -453,7 +454,6 @@ class Tensor:
     def half(self) -> Tensor: ...
     def hardshrink(self, lambd: Number=0.5) -> Tensor: ...
     def histc(self, bins: _int=100, min: Number=0, max: Number=0) -> Tensor: ...
-    def ifft(self, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
     def imag(self) -> Tensor: ...
     @overload
     def index_add(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
@@ -494,7 +494,6 @@ class Tensor:
     def int(self) -> Tensor: ...
     def int_repr(self) -> Tensor: ...
     def inverse(self) -> Tensor: ...
-    def irfft(self, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True, signal_sizes: _size=()) -> Tensor: ...
     def is_coalesced(self) -> _bool: ...
     def is_complex(self) -> _bool: ...
     def is_contiguous(self) -> _bool: ...
@@ -683,7 +682,6 @@ class Tensor:
     def random_(self, to: _int, *, generator: Generator=None) -> Tensor: ...
     @overload
     def random_(self, *, generator: Generator=None) -> Tensor: ...
-    def real(self) -> Tensor: ...
     def reciprocal(self) -> Tensor: ...
     def reciprocal_(self) -> Tensor: ...
     def refine_names(self, names: List[Union[str, None]]) -> Tensor: ...
@@ -720,7 +718,6 @@ class Tensor:
     @overload
     def resize_(self, *size: _int, memory_format: Optional[memory_format]=None) -> Tensor: ...
     def resize_as_(self, the_template: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
-    def rfft(self, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True) -> Tensor: ...
     def roll(self, shifts: Union[_int, _size], dims: Union[_int, _size]=()) -> Tensor: ...
     def rot90(self, k: _int=1, dims: _size=(0,1)) -> Tensor: ...
     def round(self) -> Tensor: ...
@@ -1317,7 +1314,6 @@ def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ...
 def feature_alpha_dropout_(self: Tensor, p: _float, train: _bool) -> Tensor: ...
 def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ...
 def feature_dropout_(self: Tensor, p: _float, train: _bool) -> Tensor: ...
-def fft(self: Tensor, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
 @overload
 def fill_(self: Tensor, value: Number) -> Tensor: ...
 @overload
@@ -1394,7 +1390,6 @@ def hann_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, lay
 def hardshrink(self: Tensor, lambd: Number=0.5) -> Tensor: ...
 def histc(self: Tensor, bins: _int=100, min: Number=0, max: Number=0, *, out: Optional[Tensor]=None) -> Tensor: ...
 def hspmm(mat1: Tensor, mat2: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
-def ifft(self: Tensor, signal_ndim: _int, normalized: _bool=False) -> Tensor: ...
 def imag(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
 def index_add(self: Tensor, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
@@ -1421,7 +1416,6 @@ def index_select(self: Tensor, dim: Union[str, None], index: Tensor, *, out: Opt
 def instance_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], use_input_stats: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ...
 def int_repr(self: Tensor) -> Tensor: ...
 def inverse(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
-def irfft(self: Tensor, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True, signal_sizes: _size=()) -> Tensor: ...
 def is_complex(self: Tensor) -> _bool: ...
 def is_distributed(self: Tensor) -> _bool: ...
 def is_floating_point(self: Tensor) -> _bool: ...
@@ -1705,7 +1699,6 @@ def result_type(tensor: Tensor, other: Number) -> _dtype: ...
 def result_type(scalar: Number, tensor: Tensor) -> _dtype: ...
 @overload
 def result_type(scalar1: Number, scalar2: Number) -> _dtype: ...
-def rfft(self: Tensor, signal_ndim: _int, normalized: _bool=False, onesided: _bool=True) -> Tensor: ...
 @overload
 def rnn_relu(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ...
 @overload
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+from typing import Optional
+
+from torch import Tensor
+
+# See https://github.com/python/mypy/issues/4146 for why these workarounds
+# is necessary
+#_int = builtins.int
+#_float = builtins.float
+#_bool = builtins.bool
+#_size = Union[Size, List[int], Tuple[int, ...]]
+
+def fft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
+def ifft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
+def rfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
+def irfft(input: Tensor, n: Optional[int] = None, dim: Optional[int]=-1, norm: Optional[str]=None) -> Tensor: ...
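These new stub signatures mirror the actual `torch.fft` functional API (input, optional length `n`, `dim`, `norm`), which is what the `_fft_transform`/`_ifft_transform` wrappers call. A small usage sketch of those parameters:

```python
import torch

x = torch.randn(10)

# Every argument declared in the stub is a real torch.fft parameter.
y = torch.fft.fft(x, n=16, dim=-1, norm="ortho")   # zero-padded to length 16, orthonormal scaling
z = torch.fft.ifft(y, n=16, dim=-1, norm="ortho")
assert torch.allclose(z.real[:10], x, atol=1e-6)   # first 10 samples round-trip back to x
```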