Unverified Commit 22165cea authored by J-shang, committed by GitHub

[Doc] update compression reference (#4667)

parent de6662a4
Compression API Reference
=========================

Pruner
------

Please refer to :doc:`../compression/pruner`.

Quantizer
---------

Please refer to :doc:`../compression/quantizer`.

Pruning Speedup
---------------

.. autoclass:: nni.compression.pytorch.speedup.ModelSpeedup
    :members:

Quantization Speedup
--------------------

.. autoclass:: nni.compression.pytorch.quantization_speedup.ModelSpeedupTensorRT
    :members:

Compression Utilities
---------------------

.. autoclass:: nni.compression.pytorch.utils.sensitivity_analysis.SensitivityAnalysis
    :members:

.. autoclass:: nni.compression.pytorch.utils.shape_dependency.ChannelDependency
    :members:

.. autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
    :members:

.. autoclass:: nni.compression.pytorch.utils.mask_conflict.ChannelMaskConflict
    :members:

.. autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
    :members:

.. autofunction:: nni.compression.pytorch.utils.counter.count_flops_params

.. autofunction:: nni.algorithms.compression.v2.pytorch.utils.pruning.compute_sparsity

Framework Related
-----------------

.. autoclass:: nni.algorithms.compression.v2.pytorch.base.Pruner
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.base.PrunerModuleWrapper

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.basic_pruner.BasicPruner
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.tools.DataCollector
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.tools.MetricsCalculator
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.tools.SparsityAllocator
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.base.BasePruningScheduler
    :members:

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.tools.TaskGenerator
    :members:

.. autoclass:: nni.compression.pytorch.compressor.Quantizer
    :members:

.. autoclass:: nni.compression.pytorch.compressor.QuantizerModuleWrapper
    :members:

.. autoclass:: nni.compression.pytorch.compressor.QuantGrad
    :members:
Model Compression
=================

nni.algorithms.compression
--------------------------

nni.compression
---------------
@@ -6,7 +6,7 @@ API Reference

    Hyperparameter Optimization <hpo>
    Neural Architecture Search <./python_api/nas>
-   Model Compression <./python_api/compression>
+   Model Compression <compression>
    Feature Engineering <./python_api/feature_engineering>
    Experiment <experiment>
    Others <./python_api/others>
@@ -51,7 +51,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Pruning Model\n\nUsing L1NormPruner pruning the model and generating the masks.\nUsually, pruners require original model and ``config_list`` as parameters.\nDetailed about how to write ``config_list`` please refer ...\n\nThis `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,\nexcept the layer named `fc3`, because `fc3` is `exclude`.\nThe final sparsity ratio for each layer is 50%. The layer named `fc3` will not be pruned.\n\n"
+    "## Pruning Model\n\nUsing L1NormPruner pruning the model and generating the masks.\nUsually, pruners require original model and ``config_list`` as parameters.\nDetailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.\n\nThis `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,\nexcept the layer named `fc3`, because `fc3` is `exclude`.\nThe final sparsity ratio for each layer is 50%. The layer named `fc3` will not be pruned.\n\n"
    ]
   },
   {
...
@@ -50,7 +50,7 @@ for epoch in range(3):
 #
 # Using L1NormPruner pruning the model and generating the masks.
 # Usually, pruners require original model and ``config_list`` as parameters.
-# Detailed about how to write ``config_list`` please refer ...
+# Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.
 #
 # This `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,
 # except the layer named `fc3`, because `fc3` is `exclude`.
...
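The `config_list` semantics described in the hunk above (type-based selection with an `exclude` entry) can be sketched in plain Python. This is a hypothetical illustration: the key names follow the tutorial's prose, but exact keys differ between NNI releases, and the matcher below is not NNI's implementation.

```python
# The keys below follow the tutorial's prose ("sparsity ratio ... 50%",
# 'op_types', 'op_names', 'exclude'); exact key names differ between NNI
# releases, so treat this as a hypothetical sketch, not the tutorial's file.
config_list = [{
    'sparsity': 0.5,                   # prune 50% of weights in matched layers
    'op_types': ['Linear', 'Conv2d'],  # match every Linear and Conv2d layer
}, {
    'exclude': True,                   # ...except the layers named here
    'op_names': ['fc3'],
}]

def is_pruned(op_type, op_name, config_list):
    """Illustrative matcher (not NNI's implementation): later entries
    override earlier ones, and an 'exclude' entry opts a layer out."""
    pruned = False
    for cfg in config_list:
        type_hit = 'op_types' not in cfg or op_type in cfg['op_types']
        name_hit = 'op_names' not in cfg or op_name in cfg['op_names']
        if type_hit and name_hit:
            pruned = not cfg.get('exclude', False)
    return pruned
```

Under this reading, `is_pruned('Linear', 'fc3', config_list)` is `False`, while every other `Linear`/`Conv2d` layer is pruned at 50% sparsity, matching the prose above.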
-c87607b7befe8496829a8cb5a8632019
\ No newline at end of file
+bacea60d39b0445d01e3a233b1bfd249
\ No newline at end of file
...@@ -102,9 +102,9 @@ If you are familiar with defining a model and training in pytorch, you can skip ...@@ -102,9 +102,9 @@ If you are familiar with defining a model and training in pytorch, you can skip
.. code-block:: none .. code-block:: none
Average test loss: 0.5603, Accuracy: 8270/10000 (83%) Average test loss: 0.5876, Accuracy: 8158/10000 (82%)
Average test loss: 0.2395, Accuracy: 9289/10000 (93%) Average test loss: 0.2501, Accuracy: 9217/10000 (92%)
Average test loss: 0.1660, Accuracy: 9527/10000 (95%) Average test loss: 0.1786, Accuracy: 9486/10000 (95%)
...@@ -116,7 +116,7 @@ Pruning Model ...@@ -116,7 +116,7 @@ Pruning Model
Using L1NormPruner pruning the model and generating the masks. Using L1NormPruner pruning the model and generating the masks.
Usually, pruners require original model and ``config_list`` as parameters. Usually, pruners require original model and ``config_list`` as parameters.
Detailed about how to write ``config_list`` please refer ... Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.
This `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned, This `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,
except the layer named `fc3`, because `fc3` is `exclude`. except the layer named `fc3`, because `fc3` is `exclude`.
...@@ -308,7 +308,7 @@ Because speed up will replace the masked big layers with dense small ones. ...@@ -308,7 +308,7 @@ Because speed up will replace the masked big layers with dense small ones.
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 38.705 seconds) **Total running time of the script:** ( 1 minutes 33.096 seconds)
.. _sphx_glr_download_tutorials_pruning_quick_start_mnist.py: .. _sphx_glr_download_tutorials_pruning_quick_start_mnist.py:
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Quantizing Model\n\nInitialize a `config_list`.\n\n" "## Quantizing Model\n\nInitialize a `config_list`.\nDetailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.\n\n"
] ]
}, },
{ {
......
@@ -39,6 +39,7 @@ for epoch in range(3):
 # ----------------
 #
 # Initialize a `config_list`.
+# Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.

 config_list = [{
     'quant_types': ['input', 'weight'],
...
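The quantization `config_list` begun in the hunk above pairs quantization targets (`quant_types`) with bit widths per matched layer. Below is a hypothetical completion — the keys beyond `quant_types` and the values are illustrative guesses, not the tutorial's exact file — plus a toy min-max affine quantizer showing what an 8-bit setting means numerically.

```python
# Hypothetical completion of the config_list whose first entry is shown in
# the hunk above; all values are illustrative, not the tutorial's exact file.
config_list = [{
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},  # 8-bit inputs and weights
    'op_types': ['Conv2d', 'Linear'],
}]

def quantize_affine(values, bits):
    """Toy min-max affine quantization: map floats onto the integer grid
    [0, 2**bits - 1], then dequantize back to floats."""
    qmax = (1 << bits) - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    quantized = [round((v - lo) / scale) for v in values]
    dequantized = [lo + q * scale for q in quantized]
    return quantized, dequantized
```

For example, `quantize_affine([0.0, 0.5, 1.0], 8)` maps the range onto 0..255, and the round trip loses at most half a quantization step per value.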
-bcaf7880c66acfb20f3e5425730e21de
\ No newline at end of file
+6907e4d1a88c7f2f64e18e34a105e68e
\ No newline at end of file
@@ -68,21 +68,22 @@ If you are familiar with defining a model and training in pytorch, you can skip

 .. code-block:: none

-    Average test loss: 0.4891, Accuracy: 8504/10000 (85%)
-    Average test loss: 0.2644, Accuracy: 9179/10000 (92%)
-    Average test loss: 0.1953, Accuracy: 9414/10000 (94%)
+    Average test loss: 0.4877, Accuracy: 8541/10000 (85%)
+    Average test loss: 0.2618, Accuracy: 9191/10000 (92%)
+    Average test loss: 0.1626, Accuracy: 9543/10000 (95%)

-.. GENERATED FROM PYTHON SOURCE LINES 38-42
+.. GENERATED FROM PYTHON SOURCE LINES 38-43

 Quantizing Model
 ----------------

 Initialize a `config_list`.
+Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.

-.. GENERATED FROM PYTHON SOURCE LINES 42-61
+.. GENERATED FROM PYTHON SOURCE LINES 43-62

 .. code-block:: default

@@ -112,11 +113,11 @@ Initialize a `config_list`.

-.. GENERATED FROM PYTHON SOURCE LINES 62-63
+.. GENERATED FROM PYTHON SOURCE LINES 63-64

 finetuning the model by using QAT

-.. GENERATED FROM PYTHON SOURCE LINES 63-71
+.. GENERATED FROM PYTHON SOURCE LINES 64-72

 .. code-block:: default

@@ -138,18 +139,20 @@ finetuning the model by using QAT

 .. code-block:: none

-    Average test loss: 0.1421, Accuracy: 9567/10000 (96%)
-    Average test loss: 0.1180, Accuracy: 9621/10000 (96%)
-    Average test loss: 0.1119, Accuracy: 9649/10000 (96%)
+    op_names ['relu1'] not found in model
+    op_names ['relu2'] not found in model
+    Average test loss: 0.1739, Accuracy: 9441/10000 (94%)
+    Average test loss: 0.1078, Accuracy: 9671/10000 (97%)
+    Average test loss: 0.0991, Accuracy: 9696/10000 (97%)

-.. GENERATED FROM PYTHON SOURCE LINES 72-73
+.. GENERATED FROM PYTHON SOURCE LINES 73-74

 export model and get calibration_config

-.. GENERATED FROM PYTHON SOURCE LINES 73-78
+.. GENERATED FROM PYTHON SOURCE LINES 74-79

 .. code-block:: default

@@ -168,7 +171,7 @@ export model and get calibration_config

 .. code-block:: none

-    calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0034], device='cuda:0'), 'weight_zero_point': tensor([71.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0020], device='cuda:0'), 'weight_zero_point': tensor([112.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 13.904684066772461}}
+    calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0034], device='cuda:0'), 'weight_zero_point': tensor([75.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0018], device='cuda:0'), 'weight_zero_point': tensor([110.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 13.838628768920898}}

@@ -176,7 +179,7 @@ export model and get calibration_config

 .. rst-class:: sphx-glr-timing

-   **Total running time of the script:** ( 1 minutes 25.558 seconds)
+   **Total running time of the script:** ( 1 minutes 51.644 seconds)

 .. _sphx_glr_download_tutorials_quantization_quick_start_mnist.py:
...
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
Computation times Computation times
================= =================
**01:48.564** total execution time for **tutorials** files: **03:24.740** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 01:38.705 | 0.0 MB | | :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:51.644 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_speed_up.py` (``pruning_speed_up.py``) | 00:09.859 | 0.0 MB | | :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 01:33.096 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
...@@ -20,9 +20,9 @@ Computation times ...@@ -20,9 +20,9 @@ Computation times
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_customize.py` (``pruning_customize.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_pruning_customize.py` (``pruning_customize.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_pruning_speed_up.py` (``pruning_speed_up.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_speed_up.py` (``quantization_speed_up.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_quantization_speed_up.py` (``quantization_speed_up.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
...@@ -50,7 +50,7 @@ for epoch in range(3): ...@@ -50,7 +50,7 @@ for epoch in range(3):
# #
# Using L1NormPruner pruning the model and generating the masks. # Using L1NormPruner pruning the model and generating the masks.
# Usually, pruners require original model and ``config_list`` as parameters. # Usually, pruners require original model and ``config_list`` as parameters.
# Detailed about how to write ``config_list`` please refer ... # Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.
# #
# This `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned, # This `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,
# except the layer named `fc3`, because `fc3` is `exclude`. # except the layer named `fc3`, because `fc3` is `exclude`.
......
@@ -39,6 +39,7 @@ for epoch in range(3):
 # ----------------
 #
 # Initialize a `config_list`.
+# Detailed about how to write ``config_list`` please refer :doc:`compression config specification <../compression/compression_config_list>`.

 config_list = [{
     'quant_types': ['input', 'weight'],
...
...@@ -35,17 +35,16 @@ def _setattr(model: Module, name: str, module: Module): ...@@ -35,17 +35,16 @@ def _setattr(model: Module, name: str, module: Module):
class Compressor: class Compressor:
""" """
The abstract base pytorch compressor. The abstract base pytorch compressor.
Parameters
----------
model
The model under compressed.
config_list
The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
""" """
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]): def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
"""
Parameters
----------
model
The model under compressed.
config_list
The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
"""
self.is_wrapped = False self.is_wrapped = False
if model is not None: if model is not None:
self.reset(model=model, config_list=config_list) self.reset(model=model, config_list=config_list)
......
...@@ -16,21 +16,22 @@ __all__ = ['Pruner'] ...@@ -16,21 +16,22 @@ __all__ = ['Pruner']
class PrunerModuleWrapper(Module): class PrunerModuleWrapper(Module):
def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor): """
""" Wrap a module to enable data parallel, forward method customization and buffer registeration.
Wrap a module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
module
The module user wants to compress.
config
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
Parameters def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
----------
module
The module user wants to compress.
config
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
super().__init__() super().__init__()
# origin layer information # origin layer information
self.module = module self.module = module
......
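The docstring moved above describes the wrapper idea: keep the original module, hold the pruning mask alongside it (a registered buffer in real NNI, so it survives `state_dict` and data parallel), and apply the mask during forward. A dependency-free sketch of that idea — `TinyLinear` and `PrunerWrapper` are hypothetical stand-ins, not NNI classes:

```python
class TinyLinear:
    """Hypothetical minimal layer: one coefficient per input feature."""
    def __init__(self, weight):
        self.weight = weight
    def forward(self, x):
        return sum(w * xi for w, xi in zip(self.weight, x))

class PrunerWrapper:
    """Sketch of the wrapper pattern: same interface as the wrapped module,
    but forward() applies the pruning mask to the weight first."""
    def __init__(self, module, module_name, weight_mask):
        self.module = module
        self.name = module_name          # wrapper shares the module's name
        self.weight_mask = weight_mask   # 1 = keep the weight, 0 = pruned
    def forward(self, x):
        original = self.module.weight
        # temporarily apply the mask, run the wrapped module, then restore
        self.module.weight = [w * m for w, m in zip(original, self.weight_mask)]
        try:
            return self.module.forward(x)
        finally:
            self.module.weight = original
```

With weight `[1.0, 2.0, 3.0]` and mask `[1, 0, 1]`, the middle weight contributes nothing to the output, while the wrapped module's stored weights stay untouched.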
...@@ -22,15 +22,14 @@ _logger = logging.getLogger(__name__) ...@@ -22,15 +22,14 @@ _logger = logging.getLogger(__name__)
class DataCollector: class DataCollector:
""" """
An abstract class for collect the data needed by the compressor. An abstract class for collect the data needed by the compressor.
Parameters
----------
compressor
The compressor binded with this DataCollector.
""" """
def __init__(self, compressor: Compressor): def __init__(self, compressor: Compressor):
"""
Parameters
----------
compressor
The compressor binded with this DataCollector.
"""
self.compressor = compressor self.compressor = compressor
def reset(self): def reset(self):
...@@ -242,42 +241,43 @@ class TrainerBasedDataCollector(DataCollector): ...@@ -242,42 +241,43 @@ class TrainerBasedDataCollector(DataCollector):
class MetricsCalculator: class MetricsCalculator:
""" """
An abstract class for calculate a kind of metrics of the given data. An abstract class for calculate a kind of metrics of the given data.
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
"""
Parameters
----------
dim
The dimensions that corresponding to the under pruning weight dimensions in collected data.
None means one-to-one correspondence between pruned dimensions and data, which equal to set `dim` as all data dimensions.
Only these `dim` will be kept and other dimensions of the data will be reduced.
Example: Parameters
----------
dim
The dimensions that corresponding to the under pruning weight dimensions in collected data.
None means one-to-one correspondence between pruned dimensions and data, which equal to set `dim` as all data dimensions.
Only these `dim` will be kept and other dimensions of the data will be reduced.
If you want to prune the Conv2d weight in filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernal-size-1, kernal-size-2]. Example:
Then the under pruning dimensions is [0], which means you want to prune the filter or out-channel.
Case 1: Directly collect the conv module weight as data to calculate the metric. If you want to prune the Conv2d weight in filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernal-size-1, kernal-size-2].
Then the data has size (32, 16, 3, 3). Then the under pruning dimensions is [0], which means you want to prune the filter or out-channel.
Mention that the dimension 0 of the data is corresponding to the under pruning weight dimension 0.
So in this case, `dim=0` will set in `__init__`.
Case 2: Use the output of the conv module as data to calculate the metric. Case 1: Directly collect the conv module weight as data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2). Then the data has size (32, 16, 3, 3).
Mention that the dimension 1 of the data is corresponding to the under pruning weight dimension 0. Mention that the dimension 0 of the data is corresponding to the under pruning weight dimension 0.
So in this case, `dim=1` will set in `__init__`. So in this case, `dim=0` will set in `__init__`.
In both of these two case, the metric of this module has size (32,). Case 2: Use the output of the conv module as data to calculate the metric.
block_sparse_size Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)). Mention that the dimension 1 of the data is corresponding to the under pruning weight dimension 0.
Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim. So in this case, `dim=1` will set in `__init__`.
Example: In both of these two case, the metric of this module has size (32,).
The under pruning weight size is (768, 768), and you want to apply a block sparse on dim=[0] with block size [64, 768], block_sparse_size
then you can set block_sparse_size=[64]. The final metric size is (12,). This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)).
""" Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim.
Example:
The under pruning weight size is (768, 768), and you want to apply a block sparse on dim=[0] with block size [64, 768],
then you can set block_sparse_size=[64]. The final metric size is (12,).
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
self.dim = dim if not isinstance(dim, int) else [dim] self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size] self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
if self.block_sparse_size is not None: if self.block_sparse_size is not None:
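The `dim` semantics documented above boil down to: keep the listed dimensions, reduce everything else. A hand-rolled illustration for `dim=0` (the filter/out-channel dimension), using nested lists in place of tensors and an L1-style sum of absolute values as the metric; the function name is illustrative, not NNI API:

```python
def l1_metric_dim0(weight):
    """Keep dimension 0, reduce all other dimensions with sum(abs(.))."""
    def abs_sum(x):
        # recurse through nested lists, summing absolute values of scalars
        if isinstance(x, (int, float)):
            return abs(x)
        return sum(abs_sum(v) for v in x)
    return [abs_sum(filt) for filt in weight]

# two 2x2 "filters" standing in for a (32, 16, 3, 3) conv weight
w = [[[1, -1], [2, 0]],
     [[0, 0], [0, 3]]]
```

`l1_metric_dim0(w)` gives one score per filter (`[4, 3]` here), matching the docstring's point that the metric for a conv layer has size `(out_channels,)`.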
...@@ -307,36 +307,35 @@ class MetricsCalculator: ...@@ -307,36 +307,35 @@ class MetricsCalculator:
class SparsityAllocator: class SparsityAllocator:
""" """
An abstract class for allocate mask based on metrics. An abstract class for allocate mask based on metrics.
Parameters
----------
pruner
The pruner that binded with this `SparsityAllocator`.
dim
The under pruning weight dimensions, which metric size should equal to the under pruning weight size on these dimensions.
None means one-to-one correspondence between pruned dimensions and metric, which equal to set `dim` as all under pruning weight dimensions.
The mask will expand to the weight size depend on `dim`.
Example:
The under pruning weight has size (2, 3, 4), and `dim=1` means the under pruning weight dimension is 1.
Then the metric should have a size (3,), i.e., `metric=[0.9, 0.1, 0.8]`.
Assuming by some kind of `SparsityAllocator` get the mask on weight dimension 1 `mask=[1, 0, 1]`,
then the dimension mask will expand to the final mask `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
block_sparse_size
This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim.
Example:
The metric size is (12,), and block_sparse_size=[64], then the mask will expand to (768,) at first before expand with `dim`.
continuous_mask
Inherit the mask already in the wrapper if set True.
""" """
def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None, def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True): block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
"""
Parameters
----------
pruner
The pruner that binded with this `SparsityAllocator`.
dim
The under pruning weight dimensions, which metric size should equal to the under pruning weight size on these dimensions.
None means one-to-one correspondence between pruned dimensions and metric, which equal to set `dim` as all under pruning weight dimensions.
The mask will expand to the weight size depend on `dim`.
Example:
The under pruning weight has size (2, 3, 4), and `dim=1` means the under pruning weight dimension is 1.
Then the metric should have a size (3,), i.e., `metric=[0.9, 0.1, 0.8]`.
Assuming by some kind of `SparsityAllocator` get the mask on weight dimension 1 `mask=[1, 0, 1]`,
then the dimension mask will expand to the final mask `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
block_sparse_size
This used to describe the block size a metric value represented. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size), and the block_sparse_size dimension position is corresponding to dim.
Example:
The metric size is (12,), and block_sparse_size=[64], then the mask will expand to (768,) at first before expand with `dim`.
continuous_mask
Inherit the mask already in the wrapper if set True.
"""
self.pruner = pruner self.pruner = pruner
self.dim = dim if not isinstance(dim, int) else [dim] self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size] self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
......
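The `(2, 3, 4)` example in the docstring above can be reproduced directly: a mask computed on weight dimension 1 is broadcast back to the full weight shape. Plain lists stand in for tensors, and `expand_mask_dim1` is an illustrative helper, not NNI API:

```python
def expand_mask_dim1(mask, shape):
    """Expand a per-channel mask on dimension 1 to the full weight shape.

    mask : list of 0/1 values, one per channel on dimension 1
    shape: (d0, d1, d2) target weight shape, with len(mask) == d1
    """
    d0, d1, d2 = shape
    assert len(mask) == d1
    # each mask value is repeated along dimension 2, then the whole
    # per-channel plane is repeated along dimension 0
    return [[[m] * d2 for m in mask] for _ in range(d0)]
```

`expand_mask_dim1([1, 0, 1], (2, 3, 4))` yields exactly the final mask written out in the docstring.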
@@ -200,6 +200,17 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
     The compact model is the origin model after pruning,
     and it may have different structure with origin_model cause of speed up.

+    Parameters
+    ----------
+    origin_model : torch.nn.Module
+        The original un-pruned model.
+    compact_model : torch.nn.Module
+        The model after speed up or original model.
+    compact_model_masks: Dict[str, Dict[str, Tensor]]
+        The masks applied on the compact model, if the original model have been speed up, this should be {}.
+    config_list : List[Dict]
+        The config_list used by pruning the original model.
+
     Returns
     -------
     Tuple[List[Dict], List[Dict], List[Dict]]
...
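The core quantity behind the sparsities that `compute_sparsity` reports is the zero fraction of the masks. A minimal hand computation under the simplifying assumption that masks are flat 0/1 lists keyed by layer name (the real function works on tensors plus a `config_list` and returns three per-config views):

```python
def mask_sparsity(masks):
    """Fraction of zeroed entries across all masks.

    masks: dict mapping layer name -> flat list of 0/1 mask values
    (a simplified stand-in for NNI's {name: {'weight': Tensor}} masks).
    """
    total = sum(len(m) for m in masks.values())
    zeros = sum(m.count(0) for m in masks.values())
    return zeros / total if total else 0.0
```

For example, two 4-element masks with three zeros between them give an overall sparsity of 3/8 = 0.375.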