Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b2c31ca2
Unverified
Commit
b2c31ca2
authored
Aug 16, 2022
by
J-shang
Committed by
GitHub
Aug 16, 2022
Browse files
[Compression] Transformer pruning example (#5017)
parent
3eca23d5
Changes
32
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2551 additions
and
131 deletions
+2551
-131
docs/source/compression/best_practices.rst
docs/source/compression/best_practices.rst
+8
-0
docs/source/compression/toctree_pruning.rst
docs/source/compression/toctree_pruning.rst
+1
-0
docs/source/examples.rst
docs/source/examples.rst
+8
-0
docs/source/tutorials/hpo_quickstart_pytorch/index.rst
docs/source/tutorials/hpo_quickstart_pytorch/index.rst
+57
-0
docs/source/tutorials/hpo_quickstart_tensorflow/index.rst
docs/source/tutorials/hpo_quickstart_tensorflow/index.rst
+57
-0
docs/source/tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
...torials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
+0
-0
docs/source/tutorials/index.rst
docs/source/tutorials/index.rst
+94
-97
docs/source/tutorials/pruning_bert_glue.ipynb
docs/source/tutorials/pruning_bert_glue.ipynb
+223
-0
docs/source/tutorials/pruning_bert_glue.py
docs/source/tutorials/pruning_bert_glue.py
+563
-0
docs/source/tutorials/pruning_bert_glue.py.md5
docs/source/tutorials/pruning_bert_glue.py.md5
+1
-0
docs/source/tutorials/pruning_bert_glue.rst
docs/source/tutorials/pruning_bert_glue.rst
+809
-0
docs/source/tutorials/pruning_bert_glue_codeobj.pickle
docs/source/tutorials/pruning_bert_glue_codeobj.pickle
+0
-0
docs/source/tutorials/sg_execution_times.rst
docs/source/tutorials/sg_execution_times.rst
+4
-2
examples/model_compress/.gitignore
examples/model_compress/.gitignore
+3
-1
examples/tutorials/.gitignore
examples/tutorials/.gitignore
+3
-1
examples/tutorials/pruning_bert_glue.py
examples/tutorials/pruning_bert_glue.py
+563
-0
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
...algorithms/compression/v2/pytorch/pruning/basic_pruner.py
+3
-1
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py
...orithms/compression/v2/pytorch/pruning/basic_scheduler.py
+1
-1
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
...orithms/compression/v2/pytorch/pruning/movement_pruner.py
+152
-28
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py
...gorithms/compression/v2/pytorch/pruning/tools/__init__.py
+1
-0
No files found.
docs/source/compression/best_practices.rst
0 → 100644
View file @
b2c31ca2
Best Practices
==============
.. toctree::
:hidden:
:maxdepth: 2
Pruning Transformer </tutorials/pruning_bert_glue>
docs/source/compression/toctree_pruning.rst
View file @
b2c31ca2
...
...
@@ -9,3 +9,4 @@ Pruning
Quickstart </tutorials/pruning_quick_start_mnist>
Pruner <pruner>
Speedup </tutorials/pruning_speedup>
Best Practices <best_practices>
docs/source/examples.rst
View file @
b2c31ca2
...
...
@@ -74,3 +74,11 @@ More examples can be found in our :githublink:`GitHub repository <examples>`.
:image: ../img/thumbnails/quantization-speed-up.svg
:background: indigo
:tags: Compression
.. cardlinkitem::
:header: Pruning Bert on Task MNLI
:description: An end to end example for how to using NNI pruning transformer and show the real speedup number
:link: tutorials/pruning_bert_glue
:image: ../img/thumbnails/pruning-tutorial.svg
:background: indigo
:tags: Compression
docs/source/tutorials/hpo_quickstart_pytorch/index.rst
0 → 100644
View file @
b2c31ca2
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
.. raw:: html
<div class="sphx-glr-thumbnails">
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with PyTorch
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
</div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt: Port PyTorch Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
</div>
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/main
/tutorials/hpo_quickstart_pytorch/model
docs/source/tutorials/hpo_quickstart_tensorflow/index.rst
0 → 100644
View file @
b2c31ca2
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
.. raw:: html
<div class="sphx-glr-thumbnails">
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with TensorFlow
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
</div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt: Port TensorFlow Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
</div>
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_tensorflow/main
/tutorials/hpo_quickstart_tensorflow/model
docs/source/tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
0 → 100644
View file @
b2c31ca2
34.6 KB
docs/source/tutorials/index.rst
View file @
b2c31ca2
:orphan:
Tutorials
=========
.. _sphx_glr_tutorials:
Tutorials
=========
.. raw:: html
<div class="sphx-glr-thumbnails">
.. raw:: html
...
...
@@ -15,157 +16,152 @@ Tutorials
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
:alt: Speedup Model with Mask
..
imag
e:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
:alt: Speedup Model with Mask
:ref:`sphx_glr_tutorials_pruning_speedup.py`
:ref:`sphx_glr_tutorials_pruning_speedup.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Speedup Model with Mask</div>
</div>
.. toctree::
:hidden:
/tutorials/pruning_speedup
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
:alt: SpeedUp Model with Calibration Config
..
imag
e:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
:alt: SpeedUp Model with Calibration Config
:ref:`sphx_glr_tutorials_quantization_speedup.py`
:ref:`sphx_glr_tutorials_quantization_speedup.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">SpeedUp Model with Calibration Config</div>
</div>
.. toctree::
:hidden:
/tutorials/quantization_speedup
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
:alt: Quantization Quickstart
..
imag
e:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
:alt: Quantization Quickstart
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Quantization Quickstart</div>
</div>
.. toctree::
:hidden:
/tutorials/quantization_quick_start_mnist
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
:alt: Pruning Quickstart
..
imag
e:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
:alt: Pruning Quickstart
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Pruning Quickstart</div>
</div>
.. toctree::
:hidden:
/tutorials/pruning_quick_start_mnist
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
:alt: Customize a new quantization algorithm
..
imag
e:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
:alt: Customize a new quantization algorithm
:ref:`sphx_glr_tutorials_quantization_customize.py`
:ref:`sphx_glr_tutorials_quantization_customize.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Customize a new quantization algorithm</div>
</div>
.. toctree::
:hidden:
/tutorials/quantization_customize
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
:alt: Use NAS Benchmarks as Datasets
..
imag
e:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
:alt: Use NAS Benchmarks as Datasets
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Use NAS Benchmarks as Datasets</div>
</div>
.. toctree::
:hidden:
/tutorials/nasbench_as_dataset
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
.. only:: html
..
figur
e:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
:alt: Customize Basic Pruner
..
imag
e:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
:alt: Customize Basic Pruner
:ref:`sphx_glr_tutorials_pruning_customize.py`
:ref:`sphx_glr_tutorials_pruning_customize.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Customize Basic Pruner</div>
</div>
.. toctree::
:hidden:
.. raw:: html
/tutorials/pruning_customize
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
:alt: Hello, NAS!
:ref:`sphx_glr_tutorials_hello_nas.py`
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
<div class="sphx-glr-thumbnail-title">Hello, NAS!</div>
</div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
:alt: Hello, NAS!
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
:alt: Pruning Transformer with NNI
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Pruning Transformer with NNI</div>
</div>
:ref:`sphx_glr_tutorials_hello_nas.py`
.. raw:: html
...
...
@@ -175,16 +171,22 @@ Tutorials
.. toctree::
:hidden:
/tutorials/pruning_speedup
/tutorials/quantization_speedup
/tutorials/quantization_quick_start_mnist
/tutorials/pruning_quick_start_mnist
/tutorials/quantization_customize
/tutorials/nasbench_as_dataset
/tutorials/pruning_customize
/tutorials/hello_nas
.. raw:: html
<div class="sphx-glr-clear"></div>
/tutorials/pruning_bert_glue
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
.. raw:: html
<div class="sphx-glr-thumbnails">
.. raw:: html
...
...
@@ -193,50 +195,44 @@ Tutorials
.. only:: html
..
figur
e:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with PyTorch
..
imag
e:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with PyTorch
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/main
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
..
figur
e:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt: Port PyTorch Quickstart to NNI
..
imag
e:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt: Port PyTorch Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/model
.. raw:: html
<div class="sphx-glr-clear">
</div>
</div>
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
.. raw:: html
<div class="sphx-glr-thumbnails">
.. raw:: html
...
...
@@ -245,31 +241,33 @@ Tutorials
.. only:: html
..
figur
e:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with TensorFlow
..
imag
e:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with TensorFlow
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_tensorflow/main
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt: Port TensorFlow Quickstart to NNI
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt: Port TensorFlow Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
</div>
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
.. raw:: html
...
...
@@ -278,11 +276,10 @@ Tutorials
.. toctree::
:hidden:
:includehidden:
/tutorials/hpo_quickstart_tensorflow/model
.. raw:: html
<div class="sphx-glr-clear"></div>
/tutorials/hpo_quickstart_pytorch/index.rst
/tutorials/hpo_quickstart_tensorflow/index.rst
...
...
docs/source/tutorials/pruning_bert_glue.ipynb
0 → 100644
View file @
b2c31ca2
This diff is collapsed.
Click to expand it.
docs/source/tutorials/pruning_bert_glue.py
0 → 100644
View file @
b2c31ca2
This diff is collapsed.
Click to expand it.
docs/source/tutorials/pruning_bert_glue.py.md5
0 → 100644
View file @
b2c31ca2
7d8ff24fe5a88d208ad2ad051f060df4
\ No newline at end of file
docs/source/tutorials/pruning_bert_glue.rst
0 → 100644
View file @
b2c31ca2
This diff is collapsed.
Click to expand it.
docs/source/tutorials/pruning_bert_glue_codeobj.pickle
0 → 100644
View file @
b2c31ca2
File added
docs/source/tutorials/sg_execution_times.rst
View file @
b2c31ca2
...
...
@@ -5,10 +5,10 @@
Computation times
=================
**0
1:45.743
** total execution time for **tutorials** files:
**0
0:27.206
** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_
quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:45.743
| 0.0 MB |
| :ref:`sphx_glr_tutorials_
pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:27.206
| 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
...
...
@@ -22,5 +22,7 @@ Computation times
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
examples/model_compress/.gitignore
View file @
b2c31ca2
...
...
@@ -3,4 +3,6 @@
data/
MNIST/
cifar-10-batches-py/
experiment_data/
\ No newline at end of file
experiment_data/
pruning/models
pruning/pruning_log
\ No newline at end of file
examples/tutorials/.gitignore
View file @
b2c31ca2
data/
log/
*.onnx
\ No newline at end of file
*.onnx
models/
pruning_log/
\ No newline at end of file
examples/tutorials/pruning_bert_glue.py
0 → 100644
View file @
b2c31ca2
This diff is collapsed.
Click to expand it.
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
View file @
b2c31ca2
...
...
@@ -189,7 +189,7 @@ class EvaluatorBasedPruner(BasicPruner):
raise
TypeError
(
f
"
{
self
.
__class__
.
__name__
}
.__init__() got multiple values for argument '
{
key
}
'"
)
merged_kwargs
[
key
]
=
value
for
key
,
value
in
def_kwargs
.
items
():
if
key
not
in
merged_kwargs
:
if
key
not
in
merged_kwargs
and
key
in
arg_names
:
merged_kwargs
[
key
]
=
value
diff
=
set
(
arg_names
).
difference
(
merged_kwargs
.
keys
())
if
diff
:
...
...
@@ -734,6 +734,8 @@ class ActivationPruner(EvaluatorBasedPruner):
def
_choose_activation
(
self
,
activation
:
str
=
'relu'
)
->
Callable
:
if
activation
==
'relu'
:
return
F
.
relu
elif
activation
==
'gelu'
:
return
F
.
gelu
elif
activation
==
'relu6'
:
return
F
.
relu6
else
:
...
...
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py
View file @
b2c31ca2
...
...
@@ -60,7 +60,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
raise
TypeError
(
f
"
{
self
.
__class__
.
__name__
}
.__init__() got multiple values for argument '
{
key
}
'"
)
merged_kwargs
[
key
]
=
value
for
key
,
value
in
def_kwargs
.
items
():
if
key
not
in
merged_kwargs
:
if
key
not
in
merged_kwargs
and
key
in
arg_names
:
merged_kwargs
[
key
]
=
value
diff
=
set
(
arg_names
).
difference
(
merged_kwargs
.
keys
())
if
diff
:
...
...
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
View file @
b2c31ca2
...
...
@@ -6,6 +6,7 @@ from __future__ import annotations
from
copy
import
deepcopy
import
logging
from
typing
import
Dict
,
List
,
Tuple
,
Callable
,
overload
from
typing_extensions
import
Literal
import
torch
from
torch
import
autograd
,
Tensor
...
...
@@ -21,15 +22,18 @@ from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from
.tools
import
(
NormalSparsityAllocator
,
ThresholdSparsityAllocator
,
StraightMetricsCalculator
)
from
..utils
import
(
LightningEvaluator
,
TorchEvaluator
TorchEvaluator
,
Scaling
)
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
from
..utils.external.huggingface
import
parser_factory
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -48,14 +52,18 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
module_name
The name of the module to compress, wrapper module shares same name.
"""
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
):
def
__init__
(
self
,
module
:
Module
,
module_name
:
str
,
config
:
Dict
,
score_size
:
List
[
int
]
|
None
=
None
):
super
().
__init__
(
module
,
module_name
,
config
)
self
.
weight_score
=
Parameter
(
torch
.
empty
(
self
.
weight
.
size
()))
# type: ignore
self
.
weight_score
=
Parameter
(
torch
.
empty
(
score_size
))
\
if
score_size
is
not
None
else
Parameter
(
torch
.
empty_like
(
module
.
weight
))
# type: ignore
torch
.
nn
.
init
.
constant_
(
self
.
weight_score
,
val
=
0.0
)
def
forward
(
self
,
*
inputs
):
# apply mask to weight, bias
self
.
module
.
weight
=
torch
.
mul
(
self
.
weight
,
_StraightThrough
.
apply
(
self
.
weight_score
,
self
.
weight_mask
))
# type: ignore
repeat
=
[
a
//
b
for
a
,
b
in
zip
(
self
.
weight
.
shape
,
self
.
weight_score
.
shape
)]
# type: ignore
weight_score
=
self
.
weight_score
for
dim
,
num
in
enumerate
(
repeat
):
weight_score
=
weight_score
.
repeat_interleave
(
num
,
dim
=
dim
)
self
.
module
.
weight
=
torch
.
mul
(
self
.
weight
,
_StraightThrough
.
apply
(
weight_score
,
self
.
weight_mask
))
# type: ignore
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
self
.
module
.
bias
=
torch
.
mul
(
self
.
bias
,
self
.
bias_mask
)
# type: ignore
return
self
.
module
(
*
inputs
)
...
...
@@ -124,9 +132,9 @@ class MovementPruner(EvaluatorBasedPruner):
Parameters
----------
model
: torch.nn.Module
model
Model to be pruned.
config_list
: List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
...
...
@@ -140,16 +148,39 @@ class MovementPruner(EvaluatorBasedPruner):
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_epochs : int
The total epoch number for training the model.
Make sure the total `optimizer.step()` in `training_epochs` is bigger than `cool_down_beginning_step`.
warm_up_step : int
warm_up_step
The total `optimizer.step()` number before start pruning for warm up.
Make sure `warm_up_step` is smaller than `cool_down_beginning_step`.
cool_down_beginning_step
: int
Make sure
`
`warm_up_step`
`
is smaller than
`
`cool_down_beginning_step`
`
.
cool_down_beginning_step
The number of steps at which sparsity stops growing, note that the sparsity stop growing doesn't mean masks not changed.
The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
training_epochs
The total epoch number for training the model.
Make sure the total `optimizer.step()` in ``training_epochs`` is bigger than `cool_down_beginning_step`.
If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
training_steps
The total step number for training the model.
Make sure ``training_epochs`` is bigger than ``cool_down_beginning_step``.
If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
regular_scale
Use to scale the movement score regular loss. In 'soft' mode, higher regular scale means higher final sparsity.
The recommended range is 1 ~ 30.
movement_mode
'hard' or 'soft'. Note that in 'soft' mode, ``sparsity`` set in the ``config_list`` means the sparsify threshold,
'soft' mode cannot precisely control the sparsity rate, but usually has higher performance compared with 'hard' mode.
``sparsity`` in 'soft' mode usually set to ``0.1``, and using ``regular_scale`` to control the final relative sparsity.
For detailed differences between 'hard' and 'soft', please refer to the paper.
In short, 'hard' means that the corresponding layer is pruned to a fixed ratio by the topk method according to the movement score,
which is the sparsity ratio set in config_list.
'soft' means that the final sparsity size will not be fixed, but the generation of the mask will be controlled by a threshold,
and the positions corresponding to scores below the threshold will be masked during the movement training process.
sparse_granularity
This is an experimental interface, by default, apply 'finegrained' pruning. If 'auto' is set, will try to apply structure pruning.
For the attention layer, will apply block sparse with size [head_width, head_width]. For the following two linear layers (FFN),
will apply output channel pruning for the first linear, and the input channel pruning for the second one.
'auto' only support partial hugingface transformers right now (bart, bert, t5).
Notes
-----
...
...
@@ -157,8 +188,10 @@ class MovementPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
TorchEvaluator
,
training_epochs
:
int
,
warm_up_step
:
int
,
cool_down_beginning_step
:
int
):
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
TorchEvaluator
,
warm_up_step
:
int
,
cool_down_beginning_step
:
int
,
training_epochs
:
int
|
None
=
None
,
training_steps
:
int
|
None
=
None
,
regular_scale
:
float
|
None
=
None
,
movement_mode
:
Literal
[
'hard'
,
'soft'
]
=
'hard'
,
sparse_granularity
:
Literal
[
'auto'
,
'finegrained'
]
=
'finegrained'
):
...
@
overload
...
...
@@ -169,14 +202,23 @@ class MovementPruner(EvaluatorBasedPruner):
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
*
args
,
**
kwargs
):
# TODO: remove in nni v3.0. Fake overload.
new_api
=
[
'evaluator'
,
'training_epochs'
,
'warm_up_step'
,
'cool_down_beginning_step'
]
new_api
=
[
'evaluator'
,
'warm_up_step'
,
'cool_down_beginning_step'
,
'training_epochs'
,
'training_steps'
,
'regular_scale'
,
'movement_mode'
,
'sparse_granularity'
]
old_api
=
[
'trainer'
,
'traced_optimizer'
,
'criterion'
,
'training_epochs'
,
'warm_up_step'
,
'cool_down_beginning_step'
]
init_kwargs
=
self
.
_init_evaluator
(
model
,
new_api
,
old_api
,
{},
args
,
kwargs
)
init_kwargs
=
{
'training_epochs'
:
None
,
'training_steps'
:
None
,
'regular_scale'
:
None
,
'movement_mode'
:
'hard'
,
'sparse_granularity'
:
'finegrained'
}
init_kwargs
=
self
.
_init_evaluator
(
model
,
new_api
,
old_api
,
init_kwargs
,
args
,
kwargs
)
self
.
training_epochs
:
int
=
init_kwargs
[
'training_epochs'
]
self
.
training_steps
:
int
|
None
=
init_kwargs
[
'training_steps'
]
if
self
.
using_evaluator
else
None
self
.
warm_up_step
:
int
=
init_kwargs
[
'warm_up_step'
]
self
.
cool_down_beginning_step
:
int
=
init_kwargs
[
'cool_down_beginning_step'
]
self
.
regular_scale
:
int
|
None
=
init_kwargs
[
'regular_scale'
]
if
self
.
using_evaluator
else
None
self
.
movement_mode
:
Literal
[
'hard'
,
'soft'
]
|
None
=
init_kwargs
[
'movement_mode'
]
if
self
.
using_evaluator
else
None
self
.
sparse_granularity
=
init_kwargs
[
'sparse_granularity'
]
if
self
.
using_evaluator
else
None
assert
self
.
warm_up_step
<
self
.
cool_down_beginning_step
,
'`warm_up_step` should smaller than `cool_down_beginning_step`'
self
.
_model_parser
=
parser_factory
(
model
)
super
().
__init__
(
model
,
config_list
)
def
_validate_config_before_canonical
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
]):
...
...
@@ -185,20 +227,61 @@ class MovementPruner(EvaluatorBasedPruner):
schema
.
validate
(
config_list
)
def
cubic_schedule
(
self
,
current_step
:
int
):
if
self
.
warm_up_step
<
current_step
<=
self
.
cool_down_beginning_step
:
wrapper_dict
=
self
.
get_modules_wrapper
()
for
config
in
self
.
config_list
:
scale
=
1
-
(
1
-
(
current_step
-
self
.
warm_up_step
)
/
(
self
.
cool_down_beginning_step
-
self
.
warm_up_step
))
**
3
current_sparsity
=
config
[
'total_sparsity'
]
*
scale
for
op_name
in
config
[
'op_names'
]:
wrapper
=
wrapper_dict
[
op_name
]
wrapper
.
config
[
'total_sparsity'
]
=
current_sparsity
wrapper_dict
=
self
.
get_modules_wrapper
()
for
config
in
self
.
config_list
:
current_sparsity
=
config
[
'total_sparsity'
]
*
self
.
_cubic_scale
(
current_step
)
for
op_name
in
config
[
'op_names'
]:
# There is an unreachable pyright error if `wrapper_dict[op_name].config['total_sparsity'] = current_sparsity`,
# seems a pyright bug...
wrapper_config
=
wrapper_dict
[
op_name
].
config
wrapper_config
[
'total_sparsity'
]
=
current_sparsity
def
_cubic_scale
(
self
,
current_step
:
int
):
if
self
.
warm_up_step
>
current_step
:
return
0
elif
current_step
>
self
.
cool_down_beginning_step
:
return
1
else
:
return
1
-
(
1
-
(
current_step
-
self
.
warm_up_step
)
/
(
self
.
cool_down_beginning_step
-
self
.
warm_up_step
))
**
3
def
_create_scalers
(
self
)
->
Scaling
|
Dict
[
str
,
Dict
[
str
,
Scaling
]]:
assert
self
.
bound_model
is
not
None
if
self
.
sparse_granularity
and
self
.
sparse_granularity
==
'auto'
and
self
.
_model_parser
:
scalers
=
{}
for
module_name
,
wrapper
in
self
.
get_modules_wrapper
().
items
():
if
self
.
_model_parser
.
is_attention
(
module_name
):
num_heads
=
self
.
_model_parser
.
get_num_heads
(
module_name
,
self
.
bound_model
)
if
num_heads
<=
0
:
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
1
])}
else
:
# assume attention layer weights are 2D
weight_h
:
int
=
wrapper
.
module
.
weight
.
shape
[
0
]
# type: ignore
weight_w
:
int
=
wrapper
.
module
.
weight
.
shape
[
1
]
# type: ignore
if
weight_h
%
num_heads
!=
0
or
weight_w
%
num_heads
!=
0
:
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
1
])}
else
:
block_h
=
weight_h
//
num_heads
block_w
=
weight_w
//
num_heads
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
block_h
,
block_w
])}
elif
self
.
_model_parser
.
is_ffn
(
module_name
,
ffn_num
=
1
):
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
1
,
wrapper
.
module
.
weight
.
shape
[
1
]])}
# type: ignore
elif
self
.
_model_parser
.
is_ffn
(
module_name
,
ffn_num
=
2
):
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
wrapper
.
module
.
weight
.
shape
[
0
],
1
])}
# type: ignore
else
:
scalers
[
module_name
]
=
{
'_default'
:
Scaling
([
1
])}
else
:
scalers
=
Scaling
([
1
])
return
scalers
def
reset_tools
(
self
):
scalers
=
self
.
_create_scalers
()
if
not
hasattr
(
self
,
'metrics_calculator'
):
self
.
metrics_calculator
=
StraightMetricsCalculator
()
if
not
hasattr
(
self
,
'sparsity_allocator'
):
self
.
sparsity_allocator
=
NormalSparsityAllocator
(
self
,
continuous_mask
=
False
)
if
self
.
movement_mode
==
'soft'
:
self
.
sparsity_allocator
=
ThresholdSparsityAllocator
(
self
,
scalers
=
scalers
,
continuous_mask
=
False
)
else
:
self
.
sparsity_allocator
=
NormalSparsityAllocator
(
self
,
scalers
=
scalers
,
continuous_mask
=
False
)
# use Adam to update the weight_score
assert
self
.
bound_model
is
not
None
...
...
@@ -206,6 +289,14 @@ class MovementPruner(EvaluatorBasedPruner):
optimizer
=
Adam
(
params
,
1e-2
)
self
.
step_counter
=
0
# TODO: waiting for api stable and experiemnts to prove this scheduler is needed.
# def lr_lambda(current_step: int):
# if current_step < self.warm_up_step:
# return float(current_step) / self.warm_up_step
# return max(0.0, float(147264 - current_step) / float(147264 - self.warm_up_step))
# lr_scheduler = LambdaLR(optimizer, lr_lambda)
# update the masks after each optimzier step
def
_optimizer_patch
():
optimizer
.
step
()
...
...
@@ -221,6 +312,17 @@ class MovementPruner(EvaluatorBasedPruner):
masks
=
self
.
sparsity_allocator
.
generate_sparsity
(
metrics
)
# type: ignore
self
.
load_masks
(
masks
)
def
_loss_patch
(
origin_loss
:
Tensor
):
if
self
.
regular_scale
is
not
None
:
l1_reg
=
0
count
=
0
for
wrapper
in
self
.
get_modules_wrapper
().
values
():
l1_reg
+=
torch
.
norm
(
torch
.
sigmoid
(
wrapper
.
weight_score
),
p
=
1
)
/
wrapper
.
weight_score
.
numel
()
# type: ignore
count
+=
1
return
origin_loss
+
self
.
regular_scale
*
self
.
_cubic_scale
(
self
.
step_counter
)
*
l1_reg
/
count
else
:
return
origin_loss
if
self
.
using_evaluator
:
# TODO: move to other place in nni v3.0
self
.
evaluator
.
unbind_model
()
...
...
@@ -228,7 +330,9 @@ class MovementPruner(EvaluatorBasedPruner):
if
not
hasattr
(
self
,
'data_collector'
):
self
.
data_collector
=
EvaluatorBasedScoreDataCollector
(
self
,
self
.
evaluator
,
after_opt_step_tasks
=
[
_optimizer_patch
],
max_epochs
=
self
.
training_epochs
)
max_epochs
=
self
.
training_epochs
,
max_steps
=
self
.
training_steps
,
loss_patch
=
_loss_patch
)
else
:
self
.
data_collector
.
reset
(
after_opt_step_tasks
=
[
_optimizer_patch
])
else
:
...
...
@@ -252,7 +356,27 @@ class MovementPruner(EvaluatorBasedPruner):
The configuration for generating the mask.
"""
_logger
.
debug
(
"Module detected to compress : %s."
,
layer
.
name
)
wrapper
=
PrunerScoredModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
)
assert
self
.
bound_model
is
not
None
# TODO: merge with _create_scalers after nni v3.0
if
self
.
sparse_granularity
and
self
.
sparse_granularity
==
'auto'
and
self
.
_model_parser
:
if
self
.
_model_parser
.
is_attention
(
layer
.
name
):
num_heads
=
self
.
_model_parser
.
get_num_heads
(
layer
.
name
,
self
.
bound_model
)
if
num_heads
<=
0
:
score_size
=
None
else
:
if
layer
.
module
.
weight
.
shape
[
0
]
%
num_heads
!=
0
or
layer
.
module
.
weight
.
shape
[
1
]
%
num_heads
!=
0
:
# type: ignore
score_size
=
None
else
:
score_size
=
[
num_heads
,
num_heads
]
elif
self
.
_model_parser
.
is_ffn
(
layer
.
name
,
ffn_num
=
1
):
score_size
=
[
layer
.
module
.
weight
.
shape
[
0
],
1
]
# type: ignore
elif
self
.
_model_parser
.
is_ffn
(
layer
.
name
,
ffn_num
=
2
):
score_size
=
[
1
,
layer
.
module
.
weight
.
shape
[
1
]]
# type: ignore
else
:
score_size
=
None
else
:
score_size
=
None
wrapper
=
PrunerScoredModuleWrapper
(
layer
.
module
,
layer
.
name
,
config
,
score_size
)
assert
hasattr
(
layer
.
module
,
'weight'
),
"module %s does not have 'weight' attribute"
%
layer
.
name
# move newly registered buffers to the same device of weight
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
# type: ignore
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py
View file @
b2c31ca2
...
...
@@ -29,6 +29,7 @@ from .metrics_calculator import (
)
from
.sparsity_allocator
import
(
NormalSparsityAllocator
,
ThresholdSparsityAllocator
,
BankSparsityAllocator
,
GlobalSparsityAllocator
,
DependencyAwareAllocator
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment