"tests/vscode:/vscode.git/clone" did not exist on "e9636216496240e0e0174631de46d107562e9215"
Commit b2c31ca2 authored by J-shang, committed by GitHub

[Compression] Transformer pruning example (#5017)

parent 3eca23d5
Best Practices
==============

.. toctree::
   :hidden:
   :maxdepth: 2

   Pruning Transformer </tutorials/pruning_bert_glue>
@@ -9,3 +9,4 @@ Pruning
   Quickstart </tutorials/pruning_quick_start_mnist>
   Pruner <pruner>
   Speedup </tutorials/pruning_speedup>
+  Best Practices <best_practices>
@@ -74,3 +74,11 @@ More examples can be found in our :githublink:`GitHub repository <examples>`.
   :image: ../img/thumbnails/quantization-speed-up.svg
   :background: indigo
   :tags: Compression
+.. cardlinkitem::
+   :header: Pruning BERT on the MNLI Task
+   :description: An end-to-end example of how to use NNI to prune a transformer and show the real speedup numbers
+   :link: tutorials/pruning_bert_glue
+   :image: ../img/thumbnails/pruning-tutorial.svg
+   :background: indigo
+   :tags: Compression
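
The tutorial itself (``pruning_bert_glue``) is collapsed in this diff, so as orientation only, the end-to-end flow the card advertises, fine-tuning BERT on an MNLI-style task, pruning it with the movement pruner, then actually speeding the model up, might look roughly like the sketch below. The model name, the ``evaluator``, the ``dummy_input`` and the step numbers are illustrative assumptions, not code taken from the tutorial.

    # Hypothetical sketch only; `evaluator` is assumed to be a TorchEvaluator wrapping the
    # Hugging Face training loop, and `dummy_input` a prepared batch of input ids.
    from transformers import BertForSequenceClassification
    from nni.compression.pytorch import ModelSpeedup
    from nni.compression.pytorch.pruning import MovementPruner

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
    config_list = [{'op_types': ['Linear'], 'sparsity': 0.9}]  # illustrative target

    pruner = MovementPruner(model, config_list, evaluator,
                            warm_up_step=3000, cool_down_beginning_step=27000,
                            training_steps=30000)
    _, masks = pruner.compress()   # train with movement scores, then generate the masks
    pruner._unwrap_model()         # remove the wrappers before the real speedup
    ModelSpeedup(model, dummy_input, masks).speedup_model()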
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:

.. raw:: html

    <div class="sphx-glr-thumbnails">

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">

.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
    :alt: HPO Quickstart with PyTorch

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`

.. raw:: html

      <div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
    </div>

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">

.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
    :alt: Port PyTorch Quickstart to NNI

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`

.. raw:: html

      <div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
    </div>

.. raw:: html

    </div>

.. toctree::
   :hidden:

   /tutorials/hpo_quickstart_pytorch/main
   /tutorials/hpo_quickstart_pytorch/model
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:

.. raw:: html

    <div class="sphx-glr-thumbnails">

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">

.. only:: html

  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
    :alt: HPO Quickstart with TensorFlow

  :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`

.. raw:: html

      <div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
    </div>

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">

.. only:: html

  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
    :alt: Port TensorFlow Quickstart to NNI

  :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`

.. raw:: html

      <div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
    </div>

.. raw:: html

    </div>

.. toctree::
   :hidden:

   /tutorials/hpo_quickstart_tensorflow/main
   /tutorials/hpo_quickstart_tensorflow/model
:orphan:

.. _sphx_glr_tutorials:

Tutorials
=========

.. raw:: html

    <div class="sphx-glr-thumbnails">
@@ -15,157 +16,152 @@ Tutorials
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
   :alt: Speedup Model with Mask
   :ref:`sphx_glr_tutorials_pruning_speedup.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Speedup Model with Mask</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/pruning_speedup
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
   :alt: SpeedUp Model with Calibration Config
   :ref:`sphx_glr_tutorials_quantization_speedup.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">SpeedUp Model with Calibration Config</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/quantization_speedup
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
   :alt: Quantization Quickstart
   :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Quantization Quickstart</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/quantization_quick_start_mnist
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
   :alt: Pruning Quickstart
   :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Pruning Quickstart</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/pruning_quick_start_mnist
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
   :alt: Customize a new quantization algorithm
   :ref:`sphx_glr_tutorials_quantization_customize.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Customize a new quantization algorithm</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/quantization_customize
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
   :alt: Use NAS Benchmarks as Datasets
   :ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Use NAS Benchmarks as Datasets</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/nasbench_as_dataset
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
   :alt: Customize Basic Pruner
   :ref:`sphx_glr_tutorials_pruning_customize.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Customize Basic Pruner</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/pruning_customize
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
.. only:: html
-  .. figure:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
+  .. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
   :alt: Hello, NAS!
   :ref:`sphx_glr_tutorials_hello_nas.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Hello, NAS!</div>
    </div>
+.. raw:: html
+    <div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
+.. only:: html
+  .. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
+   :alt: Pruning Transformer with NNI
+   :ref:`sphx_glr_tutorials_pruning_bert_glue.py`
+.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Pruning Transformer with NNI</div>
+    </div>
.. raw:: html
@@ -175,16 +171,22 @@ Tutorials
.. toctree::
   :hidden:
+   /tutorials/pruning_speedup
+   /tutorials/quantization_speedup
+   /tutorials/quantization_quick_start_mnist
+   /tutorials/pruning_quick_start_mnist
+   /tutorials/quantization_customize
+   /tutorials/nasbench_as_dataset
+   /tutorials/pruning_customize
   /tutorials/hello_nas
+   /tutorials/pruning_bert_glue
-.. raw:: html
-    <div class="sphx-glr-clear"></div>
-.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
+.. raw:: html
+    <div class="sphx-glr-thumbnails">
.. raw:: html
@@ -193,50 +195,44 @@ Tutorials
.. only:: html
-  .. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
+  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
   :alt: HPO Quickstart with PyTorch
   :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/hpo_quickstart_pytorch/main
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
-  .. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
+  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
   :alt: Port PyTorch Quickstart to NNI
   :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/hpo_quickstart_pytorch/model
.. raw:: html
-    <div class="sphx-glr-clear"></div>
+    </div>
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
.. raw:: html
    <div class="sphx-glr-thumbnails">
.. raw:: html
@@ -245,31 +241,33 @@ Tutorials
.. only:: html
-  .. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
+  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
   :alt: HPO Quickstart with TensorFlow
   :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
    </div>
-.. toctree::
-   :hidden:
-   /tutorials/hpo_quickstart_tensorflow/main
.. raw:: html
    <div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
-  .. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
+  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
   :alt: Port TensorFlow Quickstart to NNI
   :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
.. raw:: html
+      <div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
+    </div>
.. raw:: html
@@ -278,11 +276,10 @@ Tutorials
.. toctree::
   :hidden:
+   :includehidden:
-   /tutorials/hpo_quickstart_tensorflow/model
-.. raw:: html
-    <div class="sphx-glr-clear"></div>
+   /tutorials/hpo_quickstart_pytorch/index.rst
+   /tutorials/hpo_quickstart_tensorflow/index.rst
......
@@ -5,10 +5,10 @@
Computation times
=================
-**01:45.743** total execution time for **tutorials** files:
+**00:27.206** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:45.743 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``)                           | 00:27.206 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``)                                           | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
@@ -22,5 +22,7 @@ Computation times
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``)                 | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``)                     | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
@@ -3,4 +3,6 @@
data/
MNIST/
cifar-10-batches-py/
experiment_data/
+pruning/models
+pruning/pruning_log
\ No newline at end of file
data/
log/
*.onnx
+models/
+pruning_log/
\ No newline at end of file
@@ -189,7 +189,7 @@ class EvaluatorBasedPruner(BasicPruner):
                raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
            merged_kwargs[key] = value
        for key, value in def_kwargs.items():
-           if key not in merged_kwargs:
+           if key not in merged_kwargs and key in arg_names:
                merged_kwargs[key] = value
        diff = set(arg_names).difference(merged_kwargs.keys())
        if diff:
@@ -734,6 +734,8 @@ class ActivationPruner(EvaluatorBasedPruner):
    def _choose_activation(self, activation: str = 'relu') -> Callable:
        if activation == 'relu':
            return F.relu
+       elif activation == 'gelu':
+           return F.gelu
        elif activation == 'relu6':
            return F.relu6
        else:
......
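The new ``gelu`` branch mainly matters for transformer models, whose feed-forward blocks typically use GELU rather than ReLU. A hypothetical call site might look like the snippet below; the pruner class and its other arguments are assumptions for illustration, only the ``activation='gelu'`` option comes from this change, and ``model``, ``config_list`` and ``evaluator`` are taken as already defined.

    from nni.compression.pytorch.pruning import ActivationAPoZRankPruner

    # assumed constructor arguments; only `activation='gelu'` is introduced by this change
    pruner = ActivationAPoZRankPruner(model, config_list, evaluator,
                                      training_steps=1000, activation='gelu')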
@@ -60,7 +60,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
                raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
            merged_kwargs[key] = value
        for key, value in def_kwargs.items():
-           if key not in merged_kwargs:
+           if key not in merged_kwargs and key in arg_names:
                merged_kwargs[key] = value
        diff = set(arg_names).difference(merged_kwargs.keys())
        if diff:
......
@@ -6,6 +6,7 @@ from __future__ import annotations
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable, overload
+from typing_extensions import Literal

import torch
from torch import autograd, Tensor
@@ -21,15 +22,18 @@ from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from .tools import (
    NormalSparsityAllocator,
+   ThresholdSparsityAllocator,
    StraightMetricsCalculator
)
from ..utils import (
    LightningEvaluator,
-   TorchEvaluator
+   TorchEvaluator,
+   Scaling
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
+from ..utils.external.huggingface import parser_factory

_logger = logging.getLogger(__name__)
@@ -48,14 +52,18 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
    module_name
        The name of the module to compress, wrapper module shares same name.
    """
-   def __init__(self, module: Module, module_name: str, config: Dict):
+   def __init__(self, module: Module, module_name: str, config: Dict, score_size: List[int] | None = None):
        super().__init__(module, module_name, config)
-       self.weight_score = Parameter(torch.empty(self.weight.size()))  # type: ignore
+       self.weight_score = Parameter(torch.empty(score_size)) \
+           if score_size is not None else Parameter(torch.empty_like(module.weight))  # type: ignore
        torch.nn.init.constant_(self.weight_score, val=0.0)

    def forward(self, *inputs):
-       # apply mask to weight, bias
-       self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))  # type: ignore
+       repeat = [a // b for a, b in zip(self.weight.shape, self.weight_score.shape)]  # type: ignore
+       weight_score = self.weight_score
+       for dim, num in enumerate(repeat):
+           weight_score = weight_score.repeat_interleave(num, dim=dim)
+       self.module.weight = torch.mul(self.weight, _StraightThrough.apply(weight_score, self.weight_mask))  # type: ignore
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.module.bias = torch.mul(self.bias, self.bias_mask)  # type: ignore
        return self.module(*inputs)
@@ -124,9 +132,9 @@ class MovementPruner(EvaluatorBasedPruner):
    Parameters
    ----------
-   model : torch.nn.Module
+   model
        Model to be pruned.
-   config_list : List[Dict]
+   config_list
        Supported keys:
            - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
            - sparsity_per_layer : Equals to sparsity.
@@ -140,16 +148,39 @@ class MovementPruner(EvaluatorBasedPruner):
    {evaluator_docstring}
    The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
    If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
-   training_epochs : int
-       The total epoch number for training the model.
-       Make sure the total `optimizer.step()` in `training_epochs` is bigger than `cool_down_beginning_step`.
-   warm_up_step : int
+   warm_up_step
        The total number of `optimizer.step()` calls before pruning starts, used for warm-up.
-       Make sure `warm_up_step` is smaller than `cool_down_beginning_step`.
+       Make sure ``warm_up_step`` is smaller than ``cool_down_beginning_step``.
-   cool_down_beginning_step: int
+   cool_down_beginning_step
        The step at which sparsity stops growing; note that the sparsity stopping growing does not mean the masks stop changing.
        The sparsity after each `optimizer.step()` is:
        total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
+   training_epochs
+       The total number of epochs for training the model.
+       Make sure the total number of `optimizer.step()` calls within ``training_epochs`` is bigger than ``cool_down_beginning_step``.
+       If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
+   training_steps
+       The total number of steps for training the model.
+       Make sure ``training_steps`` is bigger than ``cool_down_beginning_step``.
+       If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
+   regular_scale
+       Used to scale the movement-score regularization loss. In 'soft' mode, a higher ``regular_scale`` means a higher final sparsity.
+       The recommended range is 1 ~ 30.
+   movement_mode
+       'hard' or 'soft'. Note that in 'soft' mode the ``sparsity`` set in ``config_list`` acts as a sparsifying threshold:
+       'soft' mode cannot precisely control the sparsity rate, but it usually reaches higher performance than 'hard' mode.
+       ``sparsity`` in 'soft' mode is usually set to ``0.1``, with ``regular_scale`` used to control the final relative sparsity.
+       For the detailed differences between 'hard' and 'soft', please refer to the paper.
+       In short, 'hard' means each layer is pruned to a fixed ratio (the sparsity ratio set in ``config_list``) by a top-k selection
+       on the movement scores.
+       'soft' means the final sparsity is not fixed; instead the mask is generated from a threshold, and positions whose scores fall
+       below the threshold are masked out during the movement training process.
+   sparse_granularity
+       This is an experimental interface; 'finegrained' pruning is applied by default. If 'auto' is set, structured pruning is attempted.
+       For attention layers, block sparsity with block size [head_width, head_width] is applied. For the following two linear layers (FFN),
+       output-channel pruning is applied to the first linear layer and input-channel pruning to the second one.
+       'auto' currently only supports part of the HuggingFace transformers (BART, BERT, T5).
    Notes
    -----
@@ -157,8 +188,10 @@ class MovementPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)
    @overload
-   def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
-                warm_up_step: int, cool_down_beginning_step: int):
+   def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, warm_up_step: int,
+                cool_down_beginning_step: int, training_epochs: int | None = None, training_steps: int | None = None,
+                regular_scale: float | None = None, movement_mode: Literal['hard', 'soft'] = 'hard',
+                sparse_granularity: Literal['auto', 'finegrained'] = 'finegrained'):
        ...

    @overload
@@ -169,14 +202,23 @@ class MovementPruner(EvaluatorBasedPruner):
    def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
        # TODO: remove in nni v3.0. Fake overload.
-       new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
+       new_api = ['evaluator', 'warm_up_step', 'cool_down_beginning_step', 'training_epochs', 'training_steps', 'regular_scale',
+                  'movement_mode', 'sparse_granularity']
        old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
-       init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
+       init_kwargs = {'training_epochs': None, 'training_steps': None, 'regular_scale': None, 'movement_mode': 'hard',
+                      'sparse_granularity': 'finegrained'}
+       init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
        self.training_epochs: int = init_kwargs['training_epochs']
+       self.training_steps: int | None = init_kwargs['training_steps'] if self.using_evaluator else None
        self.warm_up_step: int = init_kwargs['warm_up_step']
        self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
+       self.regular_scale: int | None = init_kwargs['regular_scale'] if self.using_evaluator else None
+       self.movement_mode: Literal['hard', 'soft'] | None = init_kwargs['movement_mode'] if self.using_evaluator else None
+       self.sparse_granularity = init_kwargs['sparse_granularity'] if self.using_evaluator else None
        assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
+       self._model_parser = parser_factory(model)
        super().__init__(model, config_list)
    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
@@ -185,20 +227,61 @@ class MovementPruner(EvaluatorBasedPruner):
        schema.validate(config_list)
    def cubic_schedule(self, current_step: int):
-       if self.warm_up_step < current_step <= self.cool_down_beginning_step:
-           wrapper_dict = self.get_modules_wrapper()
-           for config in self.config_list:
-               scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
-               current_sparsity = config['total_sparsity'] * scale
-               for op_name in config['op_names']:
-                   wrapper = wrapper_dict[op_name]
-                   wrapper.config['total_sparsity'] = current_sparsity
+       wrapper_dict = self.get_modules_wrapper()
+       for config in self.config_list:
+           current_sparsity = config['total_sparsity'] * self._cubic_scale(current_step)
+           for op_name in config['op_names']:
+               # There is an unreachable pyright error if `wrapper_dict[op_name].config['total_sparsity'] = current_sparsity`,
+               # seems a pyright bug...
+               wrapper_config = wrapper_dict[op_name].config
+               wrapper_config['total_sparsity'] = current_sparsity
    def _cubic_scale(self, current_step: int):
        if self.warm_up_step > current_step:
            return 0
        elif current_step > self.cool_down_beginning_step:
            return 1
        else:
            # e.g. (hypothetical numbers) warm_up_step=1000, cool_down_beginning_step=5000:
            # at step 3000 this returns 1 - (1 - 0.5) ** 3 = 0.875, so a layer with
            # total_sparsity=0.8 is pruned to 0.8 * 0.875 = 0.7 at that step.
            return 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
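    # What follows builds the per-module `Scaling` objects used when sparse_granularity='auto':
    # attention projection weights are scored in [head_width, head_width] blocks (one movement
    # score per block), the first FFN linear is scored per output channel, the second FFN linear
    # per input channel, and any unrecognized module falls back to element-wise scores.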
    def _create_scalers(self) -> Scaling | Dict[str, Dict[str, Scaling]]:
        assert self.bound_model is not None
        if self.sparse_granularity and self.sparse_granularity == 'auto' and self._model_parser:
            scalers = {}
            for module_name, wrapper in self.get_modules_wrapper().items():
                if self._model_parser.is_attention(module_name):
                    num_heads = self._model_parser.get_num_heads(module_name, self.bound_model)
                    if num_heads <= 0:
                        scalers[module_name] = {'_default': Scaling([1])}
                    else:
                        # assume attention layer weights are 2D
                        weight_h: int = wrapper.module.weight.shape[0]  # type: ignore
                        weight_w: int = wrapper.module.weight.shape[1]  # type: ignore
                        if weight_h % num_heads != 0 or weight_w % num_heads != 0:
                            scalers[module_name] = {'_default': Scaling([1])}
                        else:
                            block_h = weight_h // num_heads
                            block_w = weight_w // num_heads
                            scalers[module_name] = {'_default': Scaling([block_h, block_w])}
                elif self._model_parser.is_ffn(module_name, ffn_num=1):
                    scalers[module_name] = {'_default': Scaling([1, wrapper.module.weight.shape[1]])}  # type: ignore
                elif self._model_parser.is_ffn(module_name, ffn_num=2):
                    scalers[module_name] = {'_default': Scaling([wrapper.module.weight.shape[0], 1])}  # type: ignore
                else:
                    scalers[module_name] = {'_default': Scaling([1])}
        else:
            scalers = Scaling([1])
        return scalers
    def reset_tools(self):
+       scalers = self._create_scalers()
        if not hasattr(self, 'metrics_calculator'):
            self.metrics_calculator = StraightMetricsCalculator()
        if not hasattr(self, 'sparsity_allocator'):
-           self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
+           if self.movement_mode == 'soft':
+               self.sparsity_allocator = ThresholdSparsityAllocator(self, scalers=scalers, continuous_mask=False)
+           else:
+               self.sparsity_allocator = NormalSparsityAllocator(self, scalers=scalers, continuous_mask=False)
        # use Adam to update the weight_score
        assert self.bound_model is not None
@@ -206,6 +289,14 @@ class MovementPruner(EvaluatorBasedPruner):
        optimizer = Adam(params, 1e-2)
        self.step_counter = 0
        # TODO: waiting for the api to stabilize and experiments to prove this scheduler is needed.
        # def lr_lambda(current_step: int):
        #     if current_step < self.warm_up_step:
        #         return float(current_step) / self.warm_up_step
        #     return max(0.0, float(147264 - current_step) / float(147264 - self.warm_up_step))
        # lr_scheduler = LambdaLR(optimizer, lr_lambda)
        # update the masks after each optimizer step
        def _optimizer_patch():
            optimizer.step()
@@ -221,6 +312,17 @@ class MovementPruner(EvaluatorBasedPruner):
            masks = self.sparsity_allocator.generate_sparsity(metrics)  # type: ignore
            self.load_masks(masks)
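        # The loss patch below adds a movement-score regularization term: an L1 penalty on
        # sigmoid(weight_score), averaged over the wrapped modules and ramped up by the same
        # cubic schedule. In 'soft' mode this pushes more scores under the pruning threshold,
        # so a larger `regular_scale` leads to a sparser final model.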
        def _loss_patch(origin_loss: Tensor):
            if self.regular_scale is not None:
                l1_reg = 0
                count = 0
                for wrapper in self.get_modules_wrapper().values():
                    l1_reg += torch.norm(torch.sigmoid(wrapper.weight_score), p=1) / wrapper.weight_score.numel()  # type: ignore
                    count += 1
                return origin_loss + self.regular_scale * self._cubic_scale(self.step_counter) * l1_reg / count
            else:
                return origin_loss
        if self.using_evaluator:
            # TODO: move to other place in nni v3.0
            self.evaluator.unbind_model()
@@ -228,7 +330,9 @@ class MovementPruner(EvaluatorBasedPruner):
            if not hasattr(self, 'data_collector'):
                self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
                                                                       after_opt_step_tasks=[_optimizer_patch],
-                                                                      max_epochs=self.training_epochs)
+                                                                      max_epochs=self.training_epochs,
+                                                                      max_steps=self.training_steps,
+                                                                      loss_patch=_loss_patch)
            else:
                self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
        else:
@@ -252,7 +356,27 @@ class MovementPruner(EvaluatorBasedPruner):
            The configuration for generating the mask.
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
-       wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
+       assert self.bound_model is not None
+       # TODO: merge with _create_scalers after nni v3.0
+       if self.sparse_granularity and self.sparse_granularity == 'auto' and self._model_parser:
+           if self._model_parser.is_attention(layer.name):
+               num_heads = self._model_parser.get_num_heads(layer.name, self.bound_model)
+               if num_heads <= 0:
+                   score_size = None
+               else:
+                   if layer.module.weight.shape[0] % num_heads != 0 or layer.module.weight.shape[1] % num_heads != 0:  # type: ignore
+                       score_size = None
+                   else:
+                       score_size = [num_heads, num_heads]
+           elif self._model_parser.is_ffn(layer.name, ffn_num=1):
+               score_size = [layer.module.weight.shape[0], 1]  # type: ignore
+           elif self._model_parser.is_ffn(layer.name, ffn_num=2):
+               score_size = [1, layer.module.weight.shape[1]]  # type: ignore
+           else:
+               score_size = None
+       else:
+           score_size = None
+       wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, score_size)
        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
        # move newly registered buffers to the same device of weight
        wrapper.to(layer.module.weight.device)  # type: ignore
......
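Putting the new constructor surface together, a hypothetical usage sketch follows. It is not code from this commit: the evaluator is built following the general ``TorchEvaluator`` pattern (a training function plus a traced optimizer and a criterion), the exact import paths and signatures should be checked against the evaluator documentation referenced in the docstring above, and ``model`` plus the dataloader used inside ``training_func`` are assumed to be defined elsewhere.

    # Hypothetical sketch of MovementPruner with the newly added arguments.
    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression.pytorch import TorchEvaluator  # import path assumed
    from nni.compression.pytorch.pruning import MovementPruner

    def training_func(model, optimizers, criterion, lr_schedulers=None, max_steps=None, max_epochs=None):
        # assumed signature; the evaluator drives this loop until `max_steps` is reached
        model.train()
        ...  # iterate over the task dataloader, compute criterion, call optimizers.step()

    traced_optimizer = nni.trace(torch.optim.AdamW)(model.parameters(), lr=3e-5)
    evaluator = TorchEvaluator(training_func, traced_optimizer, F.cross_entropy)

    # in 'soft' mode the configured sparsity acts as a score threshold, not a target ratio
    config_list = [{'op_types': ['Linear'], 'sparsity': 0.1}]
    pruner = MovementPruner(model, config_list, evaluator,
                            warm_up_step=3000, cool_down_beginning_step=27000,
                            training_steps=30000, regular_scale=10,
                            movement_mode='soft', sparse_granularity='auto')
    _, masks = pruner.compress()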
@@ -29,6 +29,7 @@ from .metrics_calculator import (
)
from .sparsity_allocator import (
    NormalSparsityAllocator,
+   ThresholdSparsityAllocator,
    BankSparsityAllocator,
    GlobalSparsityAllocator,
    DependencyAwareAllocator
......