Unverified Commit 608492af authored by Riyasat Ohib's avatar Riyasat Ohib Committed by GitHub
Browse files

[Feat] dense to sst implementation (#1034)

* [Feat] dense to sst implementation
1. Implementation of dense_to_sst function.
2. calculating the threshold for both the cases of top-k-element and top-k-percentage (fraction)
3. assertions to verify that the top_k_elements is smaller than the numel along the same dim
4. top_k_percent to top-k conversion
5. When calculating SST, now the real part of the complex dense_freq is used instead of the magnitudes.

* [Feat, Tests] transform method addition, handling of top_k_element None case
1. Addition of a transform method
2. Adds code to handle the dim=None case for top_k_element

* [Feat, Refactor] Reorganizations, new assertions and fixes.
1. XOR for validation that both of topk percent and element are not set, or both simultaneously unset. One and only one is set.
3. Distills topk and percent both to topk using unified helper function .
5. Adds a scatter topk values function to scatter values for SST and in future DST.
6. Validation for percentage range, and ensures k is never 0.
7. Uses config validation, adds config validation for top_k_element > 0 if not None.
parent c3f88a6d
......@@ -4,10 +4,90 @@
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Optional
import torch
from torch import Tensor
# Helper Functions
def _get_k_for_topk(topk_percent: Optional[float], top_k_element: Optional[int], top_k_total_size: int) -> int:
"""Converts the top_k_percent to top_k_element when top_k_percent is provided
as the criterion for top-k calculation. When, top_k_element is used as the criterion,
simply returns the value for k. Also, ensures k is never 0 to avoid all-zero tensors.
"""
if top_k_element is None:
top_k_element = int(top_k_total_size * topk_percent / 100.0)
elif top_k_element > top_k_total_size:
raise ValueError("top_k_element for sst or dst is larger than max number of elements along top_k_dim")
# ensure we never have 100% sparsity in tensor and always have 1 surviving element!
return max(1, top_k_element)
def _scatter_topk_to_sparse_tensor(
top_k_tensor: Tensor, to_be_sparsify_tensor: Tensor, k: int, dim: Optional[int]
) -> Tensor:
"""Scatter the topk values of the to_be_sparsify_tensor to a zero tensor of the same shape
at the top-k indices of the top_k_tensor. This function allows top-k computation with a
derived tensor from to_be_sparsify_tensor.
Args:
top_k_tensor (Tensor):
The source tensor whose top-k "indices" are taken and used to extract
the corresponding "values" from the to_be_sparsify_tensor.
to_be_sparsify_tensor (Tensor):
The tensor whose values are gathered according to the top-k indices
of the top_k_tensor, and a zero tensor of same shape is populated with these
values at those indices and creates the sparse_tensor tensor.
k (int):
the value of k for top-k
dim (Optional[int]):
dimension for top-k
Returns:
(Tensor):
Returns a sparse_tensor with the same shape as the top_k_tensor and to_be_sparsify_tensor,
and populated with the values of the to_be_sparsify_tensor at the indices corresponding
to the top-k indices of the source tensor.
"""
assert (
top_k_tensor.shape == to_be_sparsify_tensor.shape
), "top_k_tensor and to_be_sparsify_tensor have different shapes!"
sparse_tensor = torch.zeros_like(to_be_sparsify_tensor)
orig_shape = sparse_tensor.shape
if dim is None and len(orig_shape) > 1:
sparse_tensor = sparse_tensor.reshape(-1)
to_be_sparsify_tensor = to_be_sparsify_tensor.reshape(-1)
top_k_tensor = top_k_tensor.reshape(-1)
dim = -1
_, i = top_k_tensor.topk(k, dim=dim)
return sparse_tensor.scatter(dim, i, to_be_sparsify_tensor.gather(dim, i)).reshape(orig_shape)
def _top_k_total_size(tensor: Tensor, topk_dim: Optional[int]) -> int:
"""Get the total size of the input tensor along the topk_dim dimension. When, the
dimension is None, get the number of elements in the tensor.
"""
top_k_total_size = tensor.numel() if topk_dim is None else tensor.shape[topk_dim]
assert top_k_total_size > 0, "Total size of input tensor along the topk_dim has to be greater than 0."
return top_k_total_size
def _dct_transform(dense: Tensor) -> Tensor:
"""Should take a tensor and perform a Discrete Cosine Transform on the tensor.
Args:
dense (Tensor):
Input dense tensor (no zeros).
Returns:
(Tensor):
transformed dense tensor DCT components
"""
raise NotImplementedError("Support for DCT has not been implemented yet!")
class Algo(Enum):
FFT = 0
DCT = 1
......@@ -20,6 +100,10 @@ class SignalSparsity:
be used both on weights, gradients and other tensors like the
optimizer state.
During initialization, this class requires a value for one of
`sst_top_k_element` or `sst_top_k_percent` and also requires a
value for one of `dst_top_k_element` or `dst_top_k_percent`.
Args:
algo (Algo):
The algorithm used. Default: FFT
......@@ -31,7 +115,7 @@ class SignalSparsity:
sst_top_k_element (int, optional):
Number of top-k elements to retain for SST. Default: None
sst_top_k_percent (float, optional):
Percent of top-k elements to retain for SST. Default: 0.1
Percent of top-k elements to retain for SST. Default: None
dst_top_k_dim (int, optional):
The dimension on which the top-k is done for DST.
E.g. -1 is the last dim. None means flatten and top-k on all dims.
......@@ -40,7 +124,7 @@ class SignalSparsity:
dst_top_k_element (int, optional):
Number of top-k elements to retain for DST. Default: None
dst_top_k_percent (float, optional):
Percent of top-k elements to retain for DST. Default: 0.1
Percent of top-k elements to retain for DST. Default: None
Example:
.. code-block:: python
......@@ -48,32 +132,106 @@ class SignalSparsity:
2d_sparser = SignalSparsity()
sst, dst = 2d_sparser.get_sst_dst(linear.weight.data)
3d_sparser = SingalSparsity(algo=Algo.DCT, sst_top_k_dim=None, dst_top_k_dim=-1, dst_top_k_element=5, dst_top_k_percent=None)
3d_sparser = SingalSparsity(algo=Algo.DCT, sst_top_k_dim=None, dst_top_k_dim=-1, sst_top_k_percent=10, dst_top_k_element=100)
conv.weight.data = 3d_sparser.get_sst_dst_weight(conv.weight.data)
"""
def __init__(self) -> None:
pass
def __init__(
self,
algo: Algo = Algo.FFT,
sst_top_k_dim: Optional[int] = -1,
sst_top_k_element: Optional[int] = None,
sst_top_k_percent: Optional[float] = None,
dst_top_k_dim: Optional[int] = -1,
dst_top_k_element: Optional[int] = None,
dst_top_k_percent: Optional[float] = None,
) -> None:
self._sst_top_k_dim = sst_top_k_dim
self._sst_top_k_element = sst_top_k_element
self._sst_top_k_percent = sst_top_k_percent
self._dst_top_k_dim = dst_top_k_dim
self._dst_top_k_element = dst_top_k_element
self._dst_top_k_percent = dst_top_k_percent
self._validate_conf()
# TODO (Min): Type checking for the following
self._transform = torch.fft.fft if algo is Algo.FFT else _dct_transform # type: ignore
def _validate_conf(self) -> None:
"""Validating the config is valid.
"""Validating if the config is valid.
For example, not both top_k_element and top_k_percent is set.
This includes asserting the following:
1. validating that one and only one of top_k_element and top_k_percent is set.
2. Asserting that both element and percentage are in valid ranges.
this should assert fail if checking fails.
Throws:
ValueError:
If validation fails.
"""
pass
# assert that both top_k_elements and top_k_percent aren't set for sst and dst
def one_and_only(a: Optional[int], b: Optional[float]) -> bool:
return (a is None) ^ (b is None)
if not (
one_and_only(self._sst_top_k_element, self._sst_top_k_percent)
and one_and_only(self._dst_top_k_element, self._dst_top_k_percent)
):
raise ValueError(
"One and only one of top_k_element and top_k_percent for "
"each of sst and dst must be provided as an argument.\n"
f"Input values are: sst element={self._sst_top_k_element}, sst percent={self._sst_top_k_percent}, "
f"dst element={self._dst_top_k_element}, dst percent={self._dst_top_k_percent}"
)
# assert that, if top_k_percent is not None, it is a valid number for a percentage.
def none_or_in_range(a: Optional[float]) -> bool:
return a is None or (0.0 < a <= 100.0)
if not (none_or_in_range(self._sst_top_k_percent) and none_or_in_range(self._dst_top_k_percent)):
raise ValueError(
"top_k_percent values for sst and dst has to be in the interval (0, 100].\n"
f"Input values are: sst percent={self._sst_top_k_percent}, dst percent={self._dst_top_k_percent}"
)
def none_or_greater_0(a: Optional[int]) -> bool:
return a is None or (0 < a)
if not (none_or_greater_0(self._sst_top_k_element) and none_or_greater_0(self._dst_top_k_element)):
raise ValueError(
"top_k_element values for sst and dst has to be greater than 0.\n"
f"Input values are: sst element={self._sst_top_k_element} "
f"and dst element={self._dst_top_k_element}"
)
def dense_to_sst(self, dense: Tensor) -> Tensor:
"""Get SST from a tensor
"""Get Signal Sparse Tensor (SST) from a dense tensor
Dense -> fft -> top-k -> results.
The input dense tensor is transformed using a transform algorithm according to the `algo`
initialization argument. The SST is then generated from the top_k_elements
(or the top_k_percentage) of values from the transformed tensor along the 'sst_top_k_dim'.
Args:
dense (Tensor):
Input dense tensor (no zeros).
Returns:
Same shaped tensor, still in dense format but in frequency domain and has zeros.
(Tensor):
Same shaped tensor as the input dense tensor, still in dense format but in frequency
domain (complex valued) and has zeros.
"""
pass
top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
dense_freq = self._transform(dense)
# NOTE: real_dense_freq can potentially be magnitude of complex frequency components
# or DCT transformed components when using DCT (currently not implemented).
# TODO: In case of the FFT, the imaginary part can perhaps be quantized or pruning can be
# done on the smaller phases.
real_dense_freq = dense_freq.real.abs()
return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)
def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
"""From dense and SST to a DST
......@@ -89,6 +247,7 @@ class SignalSparsity:
Input SST tensor (has zeros).
Returns:
(Tensor):
Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
"""
pass
......@@ -108,6 +267,7 @@ class SignalSparsity:
Delta sparse tensor, optional.
Returns:
(Tensor):
A dense tensor in real number domain from the SST.
"""
pass
......
......@@ -27,3 +27,4 @@ tests/experimental/wgit/test_cli.py
tests/experimental/wgit/test_api.py
tests/experimental/wgit/test_pygit.py
tests/experimental/wgit/test_sha1_store.py
tests/experimental/wgit/test_signal_sparsity.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.signal_sparsity import SignalSparsity
def get_test_params():
"""Helper function to create and return a list of tuples of the form:
(in_tensor, expected_tensor, dim, percent, top_k_element) to be used as parameters for tests.
"""
# input in_tensors
tensor_4x3 = torch.arange(12).reshape(4, 3)
tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2)
# Expected SST output tensors for 4x3 tensor of ascending ints
expected_4x3_None = torch.tensor(
[
[0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], # with dim=None, top-2
[0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[21.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[30.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
],
dtype=torch.complex64,
)
expected_4x3_0 = torch.tensor(
[
[0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j], # with dim=0, top-2
[0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j],
[21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
[30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
],
dtype=torch.complex64,
)
expected_4x3_1 = torch.tensor(
[
[3.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j], # with dim=1, top-2
[12.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
[21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
[30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
],
dtype=torch.complex64,
)
expected_2x2x3_1 = torch.tensor(
[
[[1.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, -1.0 + 0.0j]], # with dim=1, top-2
[[9.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, -1.0 + 0.0j]],
[[17.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, -1.0 + 0.0j]],
],
dtype=torch.complex64,
)
return [
(tensor_4x3, expected_4x3_None, None, 20, 2),
(tensor_4x3, expected_4x3_0, 0, 50, 2),
(tensor_4x3, expected_4x3_1, 1, 70, 2),
(tensor_2x2x3, expected_2x2x3_1, 1, 100, 2),
]
def get_valid_conf_arg_list():
"""Returns a map object of keyword arguments (as dicts) to be used as parameters for test_validate_conf."""
def kwargs(vals_list):
"""Maps the values in input vals_list to the keys in arg_key_list"""
arg_key_list = [
"sst_top_k_element",
"sst_top_k_percent",
"sst_top_k_dim",
"dst_top_k_element",
"dst_top_k_percent",
"dst_top_k_dim",
]
return dict(zip(arg_key_list, vals_list))
# Validate value error is raised when, either:
# 1. One and only one of sst (or dst) percent and element is not provided a value (not None).
# 2. Both of sst (or dst) percent and element is set to None.
# 3. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
element = 10
percent = 50
dim = 0
args_list = [
[element, percent, dim, element, None, dim], # case 1.
[element, None, dim, element, percent, dim],
[None, None, dim, element, None, dim], # case 2.
[element, None, dim, None, None, dim],
[0, None, dim, None, None, dim], # case 3.
[None, 0, dim, None, None, dim],
[element, None, dim, 0, None, dim],
[element, None, dim, None, 0, dim],
]
return map(kwargs, args_list)
@pytest.mark.parametrize("kwargs", get_valid_conf_arg_list())
def test_validate_conf(kwargs):
"""Validate value error is raised with each kwargs returned by get_valid_conf_arg_list"""
pytest.raises(ValueError, SignalSparsity, **kwargs)
@pytest.mark.parametrize(
"tensor, dim",
[
(torch.arange(20).reshape(4, 5), None),
(torch.arange(20).reshape(4, 5), 0),
(torch.arange(20).reshape(4, 5), 1),
(torch.arange(80).reshape(4, 5, 4), None),
(torch.arange(80).reshape(4, 5, 4), 0),
(torch.arange(80).reshape(4, 5, 4), 1),
(torch.arange(80).reshape(4, 5, 4), 2),
],
)
def test_dense_to_sst_perfect_recons(tensor, dim):
"""Tests the dense_to_sst method whether it simply performs an FFT transformation
when top_k_percent is set at 100.
"""
sparser_2d = SignalSparsity(sst_top_k_percent=100, sst_top_k_dim=dim, dst_top_k_percent=100)
assert all((sparser_2d.dense_to_sst(tensor) == torch.fft.fft(tensor)).flatten())
@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
def test_dense_to_sst_fixed(tensor, expected, dim, percent, k):
"""Tests for fixed input dense tensor and fixed expected output SST tensor for top-2 elements."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sst = sparser_2d.dense_to_sst(tensor)
objects_are_equal(sst, expected, raise_exception=True)
@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
def test_percent_element(tensor, expected, dim, percent, k):
"""Tests whether comparative values for top_k_element and top_k_percent returns same outputs"""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sst_element = sparser_2d.dense_to_sst(tensor)
sparser_2d = SignalSparsity(
sst_top_k_percent=percent, sst_top_k_element=None, sst_top_k_dim=dim, dst_top_k_percent=100
)
sst_percent = sparser_2d.dense_to_sst(tensor)
objects_are_equal(sst_element, sst_percent, raise_exception=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment