"examples/vscode:/vscode.git/clone" did not exist on "e3a80b70a2be08b531385cd146e1627cbd718fbe"
Commit e773dfcc authored by qianyj's avatar qianyj
Browse files

create branch for v2.9

parents
:orphan:
.. _sphx_glr_tutorials:
Tutorials
=========
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Introduction ------------">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
:alt: Speedup Model with Mask
:ref:`sphx_glr_tutorials_pruning_speedup.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/pruning_speedup
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
:alt: SpeedUp Model with Calibration Config
:ref:`sphx_glr_tutorials_quantization_speedup.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/quantization_speedup
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
:alt: Quantization Quickstart
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/quantization_quick_start_mnist
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
:alt: Pruning Quickstart
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/pruning_quick_start_mnist
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
:alt: Customize a new quantization algorithm
:ref:`sphx_glr_tutorials_quantization_customize.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/quantization_customize
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
:alt: Use NAS Benchmarks as Datasets
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/nasbench_as_dataset
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
:alt: Customize Basic Pruner
:ref:`sphx_glr_tutorials_pruning_customize.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/pruning_customize
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
:alt: Hello, NAS!
:ref:`sphx_glr_tutorials_hello_nas.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hello_nas
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we demonstrate how to search in the famous model space proposed in `DARTS`_.">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
:alt: Searching in DARTS search space
:ref:`sphx_glr_tutorials_darts.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/darts
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
.. only:: html
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
:alt: Pruning Bert on Task MNLI
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/pruning_bert_glue
.. raw:: html
<div class="sphx-glr-clear"></div>
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
.. only:: html
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with PyTorch
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/main
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt: Port PyTorch Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/model
.. raw:: html
<div class="sphx-glr-clear"></div>
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
.. only:: html
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with TensorFlow
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_tensorflow/main
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
.. only:: html
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt: Port TensorFlow Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_tensorflow/model
.. raw:: html
<div class="sphx-glr-clear"></div>
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Use NAS Benchmarks as Datasets\n\nIn this tutorial, we show how to use NAS Benchmarks as datasets.\nFor research purposes we sometimes desire to query the benchmarks for architecture accuracies,\nrather than train them one by one from scratch.\nNNI has provided query tools so that users can easily get the retrieve the data in NAS benchmarks.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\nThis tutorial assumes that you have already prepared your NAS benchmarks under cache directory\n(by default, ``~/.cache/nni/nasbenchmark``).\nIf you haven't, please follow the data preparation guide in :doc:`/nas/benchmarks`.\n\nAs a result, the directory should look like:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\nos.listdir(os.path.expanduser('~/.cache/nni/nasbenchmark'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pprint\n\nfrom nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats\nfrom nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats\nfrom nni.nas.benchmarks.nds import query_nds_trial_stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NAS-Bench-101\n\nUse the following architecture as an example:\n\n<img src=\"file://../../img/nas-bench-101-example.png\">\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"arch = {\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu',\n 'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5]\n}\nfor t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An architecture of NAS-Bench-101 could be trained more than once.\nEach element of the returned generator is a dict which contains one of the training results of this trial config\n(architecture + hyper-parameters) including train/valid/test accuracy,\ntraining time, number of epochs, etc. The results of NAS-Bench-201 and NDS follow similar formats.\n\n## NAS-Bench-201\n\nUse the following architecture as an example:\n\n<img src=\"file://../../img/nas-bench-201-example.png\">\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"arch = {\n '0_1': 'avg_pool_3x3',\n '0_2': 'conv_1x1',\n '1_2': 'skip_connect',\n '0_3': 'conv_1x1',\n '1_3': 'skip_connect',\n '2_3': 'skip_connect'\n}\nfor t in query_nb201_trial_stats(arch, 200, 'cifar100'):\n pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Intermediate results are also available.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n print(t['config'])\n print('Intermediates:', len(t['intermediates']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NDS\n\nUse the following architecture as an example:\n\n<img src=\"file://../../img/nas-bench-nds-example.png\">\n\nHere, ``bot_muls``, ``ds``, ``num_gs``, ``ss`` and ``ws`` stand for \"bottleneck multipliers\",\n\"depths\", \"number of groups\", \"strides\" and \"widths\" respectively.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_spec = {\n 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n 'ds': [1, 16, 1, 4],\n 'num_gs': [1, 2, 1, 2],\n 'ss': [1, 1, 2, 2],\n 'ws': [16, 64, 128, 16]\n}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use none as a wildcard.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):\n pprint.pprint(t)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_spec = {\n 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n 'ds': [1, 16, 1, 4],\n 'num_gs': [1, 2, 1, 2],\n 'ss': [1, 1, 2, 2],\n 'ws': [16, 64, 128, 16]\n}\nfor t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n pprint.pprint(t['intermediates'][:10])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_spec = {'ds': [1, 12, 12, 12], 'ss': [1, 1, 2, 2], 'ws': [16, 24, 24, 40]}\nfor t in query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10'):\n pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the first one.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"pprint.pprint(next(query_nds_trial_stats('vanilla', None, None, None, None, None)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count number.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_spec = {'num_nodes_normal': 5, 'num_nodes_reduce': 5, 'depth': 12, 'width': 32, 'aux': False, 'drop_prob': 0.0}\ncell_spec = {\n 'normal_0_op_x': 'avg_pool_3x3',\n 'normal_0_input_x': 0,\n 'normal_0_op_y': 'conv_7x1_1x7',\n 'normal_0_input_y': 1,\n 'normal_1_op_x': 'sep_conv_3x3',\n 'normal_1_input_x': 2,\n 'normal_1_op_y': 'sep_conv_5x5',\n 'normal_1_input_y': 0,\n 'normal_2_op_x': 'dil_sep_conv_3x3',\n 'normal_2_input_x': 2,\n 'normal_2_op_y': 'dil_sep_conv_3x3',\n 'normal_2_input_y': 2,\n 'normal_3_op_x': 'skip_connect',\n 'normal_3_input_x': 4,\n 'normal_3_op_y': 'dil_sep_conv_3x3',\n 'normal_3_input_y': 4,\n 'normal_4_op_x': 'conv_7x1_1x7',\n 'normal_4_input_x': 2,\n 'normal_4_op_y': 'sep_conv_3x3',\n 'normal_4_input_y': 4,\n 'normal_concat': [3, 5, 6],\n 'reduce_0_op_x': 'avg_pool_3x3',\n 'reduce_0_input_x': 0,\n 'reduce_0_op_y': 'dil_sep_conv_3x3',\n 'reduce_0_input_y': 1,\n 'reduce_1_op_x': 'sep_conv_3x3',\n 'reduce_1_input_x': 0,\n 'reduce_1_op_y': 'sep_conv_3x3',\n 'reduce_1_input_y': 0,\n 'reduce_2_op_x': 'skip_connect',\n 'reduce_2_input_x': 2,\n 'reduce_2_op_y': 'sep_conv_7x7',\n 'reduce_2_input_y': 0,\n 'reduce_3_op_x': 'conv_7x1_1x7',\n 'reduce_3_input_x': 4,\n 'reduce_3_op_y': 'skip_connect',\n 'reduce_3_input_y': 4,\n 'reduce_4_op_x': 'conv_7x1_1x7',\n 'reduce_4_input_x': 0,\n 'reduce_4_op_y': 'conv_7x1_1x7',\n 'reduce_4_input_y': 5,\n 'reduce_concat': [3, 6]\n}\n\nfor t in query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10'):\n assert t['config']['model_spec'] == model_spec\n assert t['config']['cell_spec'] == cell_spec\n pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count number.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print('NDS (amoeba) count:', len(list(query_nds_trial_stats(None, 'amoeba', None, None, None, None, None))))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
"""
Use NAS Benchmarks as Datasets
==============================
In this tutorial, we show how to use NAS Benchmarks as datasets.
For research purposes we sometimes desire to query the benchmarks for architecture accuracies,
rather than train them one by one from scratch.
NNI provides query tools so that users can easily retrieve the data in NAS benchmarks.
"""
# %%
# Prerequisites
# -------------
# This tutorial assumes that you have already prepared your NAS benchmarks under cache directory
# (by default, ``~/.cache/nni/nasbenchmark``).
# If you haven't, please follow the data preparation guide in :doc:`/nas/benchmarks`.
#
# As a result, the directory should look like:
import os
os.listdir(os.path.expanduser('~/.cache/nni/nasbenchmark'))
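
# An optional sanity check (a sketch, not part of the original tutorial): verify that
# the benchmark databases queried below are actually present in the cache directory.
import glob
for name in ('nasbench101', 'nasbench201', 'nds'):
    found = glob.glob(os.path.expanduser(f'~/.cache/nni/nasbenchmark/{name}-*.db'))
    print(name, 'database found:' if found else 'database MISSING:', found)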
# %%
import pprint
from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
from nni.nas.benchmarks.nds import query_nds_trial_stats
# %%
# NAS-Bench-101
# -------------
#
# Use the following architecture as an example:
#
# .. image:: ../../img/nas-bench-101-example.png
arch = {
'op1': 'conv3x3-bn-relu',
'op2': 'maxpool3x3',
'op3': 'conv3x3-bn-relu',
'op4': 'conv3x3-bn-relu',
'op5': 'conv1x1-bn-relu',
'input1': [0],
'input2': [1],
'input3': [2],
'input4': [0],
'input5': [0, 3, 4],
'input6': [2, 5]
}
for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):
pprint.pprint(t)
# %%
# An architecture of NAS-Bench-101 could be trained more than once.
# Each element of the returned generator is a dict which contains one of the training results of this trial config
# (architecture + hyper-parameters) including train/valid/test accuracy,
# training time, number of epochs, etc. The results of NAS-Bench-201 and NDS follow similar formats.
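#
# For example, a minimal sketch reusing the query above: collect the final test
# accuracy of every run of this architecture and average it (``test_acc`` is one of
# the keys in each result dict).

accs = [t['test_acc'] for t in query_nb101_trial_stats(arch, 108)]
print(f'{len(accs)} runs, mean test accuracy: {sum(accs) / len(accs):.2f}')

# %%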
#
# NAS-Bench-201
# -------------
#
# Use the following architecture as an example:
#
# .. image:: ../../img/nas-bench-201-example.png
arch = {
'0_1': 'avg_pool_3x3',
'0_2': 'conv_1x1',
'1_2': 'skip_connect',
'0_3': 'conv_1x1',
'1_3': 'skip_connect',
'2_3': 'skip_connect'
}
for t in query_nb201_trial_stats(arch, 200, 'cifar100'):
pprint.pprint(t)
# %%
# Intermediate results are also available.
for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):
print(t['config'])
print('Intermediates:', len(t['intermediates']))
# %%
# NDS
# ---
#
# Use the following architecture as an example:
#
# .. image:: ../../img/nas-bench-nds-example.png
#
# Here, ``bot_muls``, ``ds``, ``num_gs``, ``ss`` and ``ws`` stand for "bottleneck multipliers",
# "depths", "number of groups", "strides" and "widths" respectively.
# %%
model_spec = {
'bot_muls': [0.0, 0.25, 0.25, 0.25],
'ds': [1, 16, 1, 4],
'num_gs': [1, 2, 1, 2],
'ss': [1, 1, 2, 2],
'ws': [16, 64, 128, 16]
}
# %%
# Use ``None`` as a wildcard.
for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):
pprint.pprint(t)
# %%
model_spec = {
'bot_muls': [0.0, 0.25, 0.25, 0.25],
'ds': [1, 16, 1, 4],
'num_gs': [1, 2, 1, 2],
'ss': [1, 1, 2, 2],
'ws': [16, 64, 128, 16]
}
for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):
pprint.pprint(t['intermediates'][:10])
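
# %%
# An optional sketch (assuming ``matplotlib`` is installed; it is not imported by the
# original tutorial): plot the test accuracy curve from the intermediate results
# queried above. ``current_epoch`` and ``test_acc`` are keys of each intermediate dict.
import matplotlib.pyplot as plt

stats = next(query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10',
                                   include_intermediates=True))
epochs = [interm['current_epoch'] for interm in stats['intermediates']]
test_accs = [interm['test_acc'] for interm in stats['intermediates']]
plt.plot(epochs, test_accs)
plt.xlabel('epoch')
plt.ylabel('test accuracy')
plt.show()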
# %%
model_spec = {'ds': [1, 12, 12, 12], 'ss': [1, 1, 2, 2], 'ws': [16, 24, 24, 40]}
for t in query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10'):
pprint.pprint(t)
# %%
# Get the first one.
pprint.pprint(next(query_nds_trial_stats('vanilla', None, None, None, None, None)))
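
# %%
# A defensive variant (just a sketch): the query functions return generators, so passing
# a default to ``next`` avoids a ``StopIteration`` when nothing matches the query.
first = next(query_nds_trial_stats('vanilla', None, None, None, None, None), None)
print('found a trial' if first is not None else 'no matching trial found')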
# %%
# Query a specific architecture in the NAS cell search space.
model_spec = {'num_nodes_normal': 5, 'num_nodes_reduce': 5, 'depth': 12, 'width': 32, 'aux': False, 'drop_prob': 0.0}
cell_spec = {
'normal_0_op_x': 'avg_pool_3x3',
'normal_0_input_x': 0,
'normal_0_op_y': 'conv_7x1_1x7',
'normal_0_input_y': 1,
'normal_1_op_x': 'sep_conv_3x3',
'normal_1_input_x': 2,
'normal_1_op_y': 'sep_conv_5x5',
'normal_1_input_y': 0,
'normal_2_op_x': 'dil_sep_conv_3x3',
'normal_2_input_x': 2,
'normal_2_op_y': 'dil_sep_conv_3x3',
'normal_2_input_y': 2,
'normal_3_op_x': 'skip_connect',
'normal_3_input_x': 4,
'normal_3_op_y': 'dil_sep_conv_3x3',
'normal_3_input_y': 4,
'normal_4_op_x': 'conv_7x1_1x7',
'normal_4_input_x': 2,
'normal_4_op_y': 'sep_conv_3x3',
'normal_4_input_y': 4,
'normal_concat': [3, 5, 6],
'reduce_0_op_x': 'avg_pool_3x3',
'reduce_0_input_x': 0,
'reduce_0_op_y': 'dil_sep_conv_3x3',
'reduce_0_input_y': 1,
'reduce_1_op_x': 'sep_conv_3x3',
'reduce_1_input_x': 0,
'reduce_1_op_y': 'sep_conv_3x3',
'reduce_1_input_y': 0,
'reduce_2_op_x': 'skip_connect',
'reduce_2_input_x': 2,
'reduce_2_op_y': 'sep_conv_7x7',
'reduce_2_input_y': 0,
'reduce_3_op_x': 'conv_7x1_1x7',
'reduce_3_input_x': 4,
'reduce_3_op_y': 'skip_connect',
'reduce_3_input_y': 4,
'reduce_4_op_x': 'conv_7x1_1x7',
'reduce_4_input_x': 0,
'reduce_4_op_y': 'conv_7x1_1x7',
'reduce_4_input_y': 5,
'reduce_concat': [3, 6]
}
for t in query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10'):
assert t['config']['model_spec'] == model_spec
assert t['config']['cell_spec'] == cell_spec
pprint.pprint(t)
# %%
# Count the number of trials.
print('NDS (amoeba) count:', len(list(query_nds_trial_stats(None, 'amoeba', None, None, None, None, None))))
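
# %%
# The same counting pattern works for the other proposers seen in this tutorial
# (a sketch; only the ``amoeba`` count is reported in the generated tutorial output).
for proposer in ('vanilla', 'resnet', 'resnext-a'):
    print(f'NDS ({proposer}) count:', len(list(query_nds_trial_stats(None, proposer, None, None, None, None, None))))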
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/nasbench_as_dataset.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_tutorials_nasbench_as_dataset.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorials_nasbench_as_dataset.py:
Use NAS Benchmarks as Datasets
==============================
In this tutorial, we show how to use NAS Benchmarks as datasets.
For research purposes we sometimes desire to query the benchmarks for architecture accuracies,
rather than train them one by one from scratch.
NNI provides query tools so that users can easily retrieve the data in NAS benchmarks.
.. GENERATED FROM PYTHON SOURCE LINES 12-19
Prerequisites
-------------
This tutorial assumes that you have already prepared your NAS benchmarks under cache directory
(by default, ``~/.cache/nni/nasbenchmark``).
If you haven't, please follow the data preparation guide in :doc:`/nas/benchmarks`.
As a result, the directory should look like:
.. GENERATED FROM PYTHON SOURCE LINES 19-23
.. code-block:: default
import os
os.listdir(os.path.expanduser('~/.cache/nni/nasbenchmark'))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
['nasbench101-209f5694.db', 'nasbench201-b2b60732.db', 'nds-5745c235.db']
.. GENERATED FROM PYTHON SOURCE LINES 24-30
.. code-block:: default
import pprint
from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
from nni.nas.benchmarks.nds import query_nds_trial_stats
.. GENERATED FROM PYTHON SOURCE LINES 31-37
NAS-Bench-101
-------------
Use the following architecture as an example:
.. image:: ../../img/nas-bench-101-example.png
.. GENERATED FROM PYTHON SOURCE LINES 37-54
.. code-block:: default
arch = {
'op1': 'conv3x3-bn-relu',
'op2': 'maxpool3x3',
'op3': 'conv3x3-bn-relu',
'op4': 'conv3x3-bn-relu',
'op5': 'conv1x1-bn-relu',
'input1': [0],
'input2': [1],
'input3': [2],
'input4': [0],
'input5': [0, 3, 4],
'input6': [2, 5]
}
for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):
pprint.pprint(t)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
[2022-02-28 13:48:51] INFO (nni.nas.benchmarks.utils/MainThread) "/home/yugzhan/.cache/nni/nasbenchmark/nasbench101-209f5694.db" already exists. Checking hash.
{'config': {'arch': {'input1': [0],
'input2': [1],
'input3': [2],
'input4': [0],
'input5': [0, 3, 4],
'input6': [2, 5],
'op1': 'conv3x3-bn-relu',
'op2': 'maxpool3x3',
'op3': 'conv3x3-bn-relu',
'op4': 'conv3x3-bn-relu',
'op5': 'conv1x1-bn-relu'},
'hash': '00005c142e6f48ac74fdcf73e3439874',
'id': 4,
'num_epochs': 108,
'num_vertices': 7},
'id': 10,
'intermediates': [{'current_epoch': 54,
'id': 19,
'test_acc': 77.40384340286255,
'train_acc': 82.82251358032227,
'training_time': 883.4580078125,
'valid_acc': 77.76442170143127},
{'current_epoch': 108,
'id': 20,
'test_acc': 92.11738705635071,
'train_acc': 100.0,
'training_time': 1769.1279296875,
'valid_acc': 92.41786599159241}],
'parameters': 8.55553,
'test_acc': 92.11738705635071,
'train_acc': 100.0,
'training_time': 106147.67578125,
'valid_acc': 92.41786599159241}
{'config': {'arch': {'input1': [0],
'input2': [1],
'input3': [2],
'input4': [0],
'input5': [0, 3, 4],
'input6': [2, 5],
'op1': 'conv3x3-bn-relu',
'op2': 'maxpool3x3',
'op3': 'conv3x3-bn-relu',
'op4': 'conv3x3-bn-relu',
'op5': 'conv1x1-bn-relu'},
'hash': '00005c142e6f48ac74fdcf73e3439874',
'id': 4,
'num_epochs': 108,
'num_vertices': 7},
'id': 11,
'intermediates': [{'current_epoch': 54,
'id': 21,
'test_acc': 82.04126358032227,
'train_acc': 87.96073794364929,
'training_time': 883.6810302734375,
'valid_acc': 82.91265964508057},
{'current_epoch': 108,
'id': 22,
'test_acc': 91.90705418586731,
'train_acc': 100.0,
'training_time': 1768.2509765625,
'valid_acc': 92.45793223381042}],
'parameters': 8.55553,
'test_acc': 91.90705418586731,
'train_acc': 100.0,
'training_time': 106095.05859375,
'valid_acc': 92.45793223381042}
{'config': {'arch': {'input1': [0],
'input2': [1],
'input3': [2],
'input4': [0],
'input5': [0, 3, 4],
'input6': [2, 5],
'op1': 'conv3x3-bn-relu',
'op2': 'maxpool3x3',
'op3': 'conv3x3-bn-relu',
'op4': 'conv3x3-bn-relu',
'op5': 'conv1x1-bn-relu'},
'hash': '00005c142e6f48ac74fdcf73e3439874',
'id': 4,
'num_epochs': 108,
'num_vertices': 7},
'id': 12,
'intermediates': [{'current_epoch': 54,
'id': 23,
'test_acc': 80.58894276618958,
'train_acc': 86.34815812110901,
'training_time': 883.4569702148438,
'valid_acc': 81.1598539352417},
{'current_epoch': 108,
'id': 24,
'test_acc': 92.15745329856873,
'train_acc': 100.0,
'training_time': 1768.9759521484375,
'valid_acc': 93.04887652397156}],
'parameters': 8.55553,
'test_acc': 92.15745329856873,
'train_acc': 100.0,
'training_time': 106138.55712890625,
'valid_acc': 93.04887652397156}
.. GENERATED FROM PYTHON SOURCE LINES 55-66
An architecture of NAS-Bench-101 could be trained more than once.
Each element of the returned generator is a dict which contains one of the training results of this trial config
(architecture + hyper-parameters) including train/valid/test accuracy,
training time, number of epochs, etc. The results of NAS-Bench-201 and NDS follow similar formats.
NAS-Bench-201
-------------
Use the following architecture as an example:
.. image:: ../../img/nas-bench-201-example.png
.. GENERATED FROM PYTHON SOURCE LINES 66-78
.. code-block:: default
arch = {
'0_1': 'avg_pool_3x3',
'0_2': 'conv_1x1',
'1_2': 'skip_connect',
'0_3': 'conv_1x1',
'1_3': 'skip_connect',
'2_3': 'skip_connect'
}
for t in query_nb201_trial_stats(arch, 200, 'cifar100'):
pprint.pprint(t)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
[2022-02-28 13:49:09] INFO (nni.nas.benchmarks.utils/MainThread) "/home/yugzhan/.cache/nni/nasbenchmark/nasbench201-b2b60732.db" already exists. Checking hash.
{'config': {'arch': {'0_1': 'avg_pool_3x3',
'0_2': 'conv_1x1',
'0_3': 'conv_1x1',
'1_2': 'skip_connect',
'1_3': 'skip_connect',
'2_3': 'skip_connect'},
'dataset': 'cifar100',
'id': 7,
'num_cells': 5,
'num_channels': 16,
'num_epochs': 200},
'flops': 15.65322,
'id': 3,
'latency': 0.013182918230692545,
'ori_test_acc': 53.11,
'ori_test_evaluation_time': 1.0195916947864352,
'ori_test_loss': 1.7307863704681397,
'parameters': 0.135156,
'seed': 999,
'test_acc': 53.07999995727539,
'test_evaluation_time': 0.5097958473932176,
'test_loss': 1.731276072692871,
'train_acc': 57.82,
'train_loss': 1.5116578379058838,
'training_time': 2888.4371995925903,
'valid_acc': 53.14000000610351,
'valid_evaluation_time': 0.5097958473932176,
'valid_loss': 1.7302966793060304}
{'config': {'arch': {'0_1': 'avg_pool_3x3',
'0_2': 'conv_1x1',
'0_3': 'conv_1x1',
'1_2': 'skip_connect',
'1_3': 'skip_connect',
'2_3': 'skip_connect'},
'dataset': 'cifar100',
'id': 7,
'num_cells': 5,
'num_channels': 16,
'num_epochs': 200},
'flops': 15.65322,
'id': 7,
'latency': 0.013182918230692545,
'ori_test_acc': 51.93,
'ori_test_evaluation_time': 1.0195916947864352,
'ori_test_loss': 1.7572312774658203,
'parameters': 0.135156,
'seed': 777,
'test_acc': 51.979999938964845,
'test_evaluation_time': 0.5097958473932176,
'test_loss': 1.7429540189743042,
'train_acc': 57.578,
'train_loss': 1.5114233912658692,
'training_time': 2888.4371995925903,
'valid_acc': 51.88,
'valid_evaluation_time': 0.5097958473932176,
'valid_loss': 1.7715086591720581}
{'config': {'arch': {'0_1': 'avg_pool_3x3',
'0_2': 'conv_1x1',
'0_3': 'conv_1x1',
'1_2': 'skip_connect',
'1_3': 'skip_connect',
'2_3': 'skip_connect'},
'dataset': 'cifar100',
'id': 7,
'num_cells': 5,
'num_channels': 16,
'num_epochs': 200},
'flops': 15.65322,
'id': 11,
'latency': 0.013182918230692545,
'ori_test_acc': 53.38,
'ori_test_evaluation_time': 1.0195916947864352,
'ori_test_loss': 1.7281623031616211,
'parameters': 0.135156,
'seed': 888,
'test_acc': 53.67999998779297,
'test_evaluation_time': 0.5097958473932176,
'test_loss': 1.7327697801589965,
'train_acc': 57.792,
'train_loss': 1.5091403088760376,
'training_time': 2888.4371995925903,
'valid_acc': 53.08000000610352,
'valid_evaluation_time': 0.5097958473932176,
'valid_loss': 1.7235548280715942}
.. GENERATED FROM PYTHON SOURCE LINES 79-80
Intermediate results are also available.
.. GENERATED FROM PYTHON SOURCE LINES 80-85
.. code-block:: default
for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):
print(t['config'])
print('Intermediates:', len(t['intermediates']))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}
Intermediates: 12
{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}
Intermediates: 200
{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}
Intermediates: 200
{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}
Intermediates: 200
.. GENERATED FROM PYTHON SOURCE LINES 86-95
NDS
---
Use the following architecture as an example:
.. image:: ../../img/nas-bench-nds-example.png
Here, ``bot_muls``, ``ds``, ``num_gs``, ``ss`` and ``ws`` stand for "bottleneck multipliers",
"depths", "number of groups", "strides" and "widths" respectively.
.. GENERATED FROM PYTHON SOURCE LINES 97-105
.. code-block:: default
model_spec = {
'bot_muls': [0.0, 0.25, 0.25, 0.25],
'ds': [1, 16, 1, 4],
'num_gs': [1, 2, 1, 2],
'ss': [1, 1, 2, 2],
'ws': [16, 64, 128, 16]
}
.. GENERATED FROM PYTHON SOURCE LINES 106-107
Use ``None`` as a wildcard.
.. GENERATED FROM PYTHON SOURCE LINES 107-110
.. code-block:: default
for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):
pprint.pprint(t)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
[2022-02-28 13:49:36] INFO (nni.nas.benchmarks.utils/MainThread) "/home/yugzhan/.cache/nni/nasbenchmark/nds-5745c235.db" already exists. Checking hash.
{'best_test_acc': 90.48,
'best_train_acc': 96.356,
'best_train_loss': 0.116,
'config': {'base_lr': 0.1,
'cell_spec': {},
'dataset': 'cifar10',
'generator': 'random',
'id': 45505,
'model_family': 'residual_bottleneck',
'model_spec': {'bot_muls': [0.0, 0.25, 0.25, 0.25],
'ds': [1, 16, 1, 4],
'num_gs': [1, 2, 1, 2],
'ss': [1, 1, 2, 2],
'ws': [16, 64, 128, 16]},
'num_epochs': 100,
'proposer': 'resnext-a',
'weight_decay': 0.0005},
'final_test_acc': 90.39,
'final_train_acc': 96.298,
'final_train_loss': 0.116,
'flops': 69.890986,
'id': 45505,
'iter_time': 0.065,
'parameters': 0.083002,
'seed': 1}
.. GENERATED FROM PYTHON SOURCE LINES 111-121
.. code-block:: default
model_spec = {
'bot_muls': [0.0, 0.25, 0.25, 0.25],
'ds': [1, 16, 1, 4],
'num_gs': [1, 2, 1, 2],
'ss': [1, 1, 2, 2],
'ws': [16, 64, 128, 16]
}
for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):
pprint.pprint(t['intermediates'][:10])
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
[{'current_epoch': 1,
'id': 4494501,
'test_acc': 41.76,
'train_acc': 30.421000000000006,
'train_loss': 1.793},
{'current_epoch': 2,
'id': 4494502,
'test_acc': 54.66,
'train_acc': 47.24,
'train_loss': 1.415},
{'current_epoch': 3,
'id': 4494503,
'test_acc': 59.97,
'train_acc': 56.983,
'train_loss': 1.179},
{'current_epoch': 4,
'id': 4494504,
'test_acc': 62.91,
'train_acc': 61.955,
'train_loss': 1.048},
{'current_epoch': 5,
'id': 4494505,
'test_acc': 66.16,
'train_acc': 64.493,
'train_loss': 0.983},
{'current_epoch': 6,
'id': 4494506,
'test_acc': 66.5,
'train_acc': 66.274,
'train_loss': 0.937},
{'current_epoch': 7,
'id': 4494507,
'test_acc': 67.55,
'train_acc': 67.426,
'train_loss': 0.907},
{'current_epoch': 8,
'id': 4494508,
'test_acc': 69.45,
'train_acc': 68.45400000000001,
'train_loss': 0.878},
{'current_epoch': 9,
'id': 4494509,
'test_acc': 70.14,
'train_acc': 69.295,
'train_loss': 0.857},
{'current_epoch': 10,
'id': 4494510,
'test_acc': 69.47,
'train_acc': 70.304,
'train_loss': 0.832}]
.. GENERATED FROM PYTHON SOURCE LINES 122-126
.. code-block:: default
model_spec = {'ds': [1, 12, 12, 12], 'ss': [1, 1, 2, 2], 'ws': [16, 24, 24, 40]}
for t in query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10'):
pprint.pprint(t)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
{'best_test_acc': 93.58,
'best_train_acc': 99.772,
'best_train_loss': 0.011,
'config': {'base_lr': 0.1,
'cell_spec': {},
'dataset': 'cifar10',
'generator': 'random',
'id': 108998,
'model_family': 'residual_basic',
'model_spec': {'ds': [1, 12, 12, 12],
'ss': [1, 1, 2, 2],
'ws': [16, 24, 24, 40]},
'num_epochs': 100,
'proposer': 'resnet',
'weight_decay': 0.0005},
'final_test_acc': 93.49,
'final_train_acc': 99.772,
'final_train_loss': 0.011,
'flops': 184.519578,
'id': 108998,
'iter_time': 0.059,
'parameters': 0.594138,
'seed': 1}
.. GENERATED FROM PYTHON SOURCE LINES 127-128
Get the first one.
.. GENERATED FROM PYTHON SOURCE LINES 128-130
.. code-block:: default
pprint.pprint(next(query_nds_trial_stats('vanilla', None, None, None, None, None)))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
{'best_test_acc': 84.5,
'best_train_acc': 89.66499999999999,
'best_train_loss': 0.302,
'config': {'base_lr': 0.1,
'cell_spec': {},
'dataset': 'cifar10',
'generator': 'random',
'id': 139492,
'model_family': 'vanilla',
'model_spec': {'ds': [1, 12, 12, 12],
'ss': [1, 1, 2, 2],
'ws': [16, 24, 32, 40]},
'num_epochs': 100,
'proposer': 'vanilla',
'weight_decay': 0.0005},
'final_test_acc': 84.35,
'final_train_acc': 89.633,
'final_train_loss': 0.303,
'flops': 208.36393,
'id': 154692,
'iter_time': 0.058,
'parameters': 0.68977,
'seed': 1}
.. GENERATED FROM PYTHON SOURCE LINES 131-132
Query a specific architecture in the NAS cell search space.
.. GENERATED FROM PYTHON SOURCE LINES 132-183
.. code-block:: default
model_spec = {'num_nodes_normal': 5, 'num_nodes_reduce': 5, 'depth': 12, 'width': 32, 'aux': False, 'drop_prob': 0.0}
cell_spec = {
'normal_0_op_x': 'avg_pool_3x3',
'normal_0_input_x': 0,
'normal_0_op_y': 'conv_7x1_1x7',
'normal_0_input_y': 1,
'normal_1_op_x': 'sep_conv_3x3',
'normal_1_input_x': 2,
'normal_1_op_y': 'sep_conv_5x5',
'normal_1_input_y': 0,
'normal_2_op_x': 'dil_sep_conv_3x3',
'normal_2_input_x': 2,
'normal_2_op_y': 'dil_sep_conv_3x3',
'normal_2_input_y': 2,
'normal_3_op_x': 'skip_connect',
'normal_3_input_x': 4,
'normal_3_op_y': 'dil_sep_conv_3x3',
'normal_3_input_y': 4,
'normal_4_op_x': 'conv_7x1_1x7',
'normal_4_input_x': 2,
'normal_4_op_y': 'sep_conv_3x3',
'normal_4_input_y': 4,
'normal_concat': [3, 5, 6],
'reduce_0_op_x': 'avg_pool_3x3',
'reduce_0_input_x': 0,
'reduce_0_op_y': 'dil_sep_conv_3x3',
'reduce_0_input_y': 1,
'reduce_1_op_x': 'sep_conv_3x3',
'reduce_1_input_x': 0,
'reduce_1_op_y': 'sep_conv_3x3',
'reduce_1_input_y': 0,
'reduce_2_op_x': 'skip_connect',
'reduce_2_input_x': 2,
'reduce_2_op_y': 'sep_conv_7x7',
'reduce_2_input_y': 0,
'reduce_3_op_x': 'conv_7x1_1x7',
'reduce_3_input_x': 4,
'reduce_3_op_y': 'skip_connect',
'reduce_3_input_y': 4,
'reduce_4_op_x': 'conv_7x1_1x7',
'reduce_4_input_x': 0,
'reduce_4_op_y': 'conv_7x1_1x7',
'reduce_4_input_y': 5,
'reduce_concat': [3, 6]
}
for t in query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10'):
assert t['config']['model_spec'] == model_spec
assert t['config']['cell_spec'] == cell_spec
pprint.pprint(t)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
{'best_test_acc': 93.37,
'best_train_acc': 99.91,
'best_train_loss': 0.006,
'config': {'base_lr': 0.1,
'cell_spec': {'normal_0_input_x': 0,
'normal_0_input_y': 1,
'normal_0_op_x': 'avg_pool_3x3',
'normal_0_op_y': 'conv_7x1_1x7',
'normal_1_input_x': 2,
'normal_1_input_y': 0,
'normal_1_op_x': 'sep_conv_3x3',
'normal_1_op_y': 'sep_conv_5x5',
'normal_2_input_x': 2,
'normal_2_input_y': 2,
'normal_2_op_x': 'dil_sep_conv_3x3',
'normal_2_op_y': 'dil_sep_conv_3x3',
'normal_3_input_x': 4,
'normal_3_input_y': 4,
'normal_3_op_x': 'skip_connect',
'normal_3_op_y': 'dil_sep_conv_3x3',
'normal_4_input_x': 2,
'normal_4_input_y': 4,
'normal_4_op_x': 'conv_7x1_1x7',
'normal_4_op_y': 'sep_conv_3x3',
'normal_concat': [3, 5, 6],
'reduce_0_input_x': 0,
'reduce_0_input_y': 1,
'reduce_0_op_x': 'avg_pool_3x3',
'reduce_0_op_y': 'dil_sep_conv_3x3',
'reduce_1_input_x': 0,
'reduce_1_input_y': 0,
'reduce_1_op_x': 'sep_conv_3x3',
'reduce_1_op_y': 'sep_conv_3x3',
'reduce_2_input_x': 2,
'reduce_2_input_y': 0,
'reduce_2_op_x': 'skip_connect',
'reduce_2_op_y': 'sep_conv_7x7',
'reduce_3_input_x': 4,
'reduce_3_input_y': 4,
'reduce_3_op_x': 'conv_7x1_1x7',
'reduce_3_op_y': 'skip_connect',
'reduce_4_input_x': 0,
'reduce_4_input_y': 5,
'reduce_4_op_x': 'conv_7x1_1x7',
'reduce_4_op_y': 'conv_7x1_1x7',
'reduce_concat': [3, 6]},
'dataset': 'cifar10',
'generator': 'random',
'id': 1,
'model_family': 'nas_cell',
'model_spec': {'aux': False,
'depth': 12,
'drop_prob': 0.0,
'num_nodes_normal': 5,
'num_nodes_reduce': 5,
'width': 32},
'num_epochs': 100,
'proposer': 'amoeba',
'weight_decay': 0.0005},
'final_test_acc': 93.27,
'final_train_acc': 99.91,
'final_train_loss': 0.006,
'flops': 664.400586,
'id': 1,
'iter_time': 0.281,
'parameters': 4.190314,
'seed': 1}
.. GENERATED FROM PYTHON SOURCE LINES 184-185
Count the number of trials.
.. GENERATED FROM PYTHON SOURCE LINES 185-186
.. code-block:: default
print('NDS (amoeba) count:', len(list(query_nds_trial_stats(None, 'amoeba', None, None, None, None, None))))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
NDS (amoeba) count: 5107
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 2.214 seconds)
.. _sphx_glr_download_tutorials_nasbench_as_dataset.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: nasbench_as_dataset.py <nasbench_as_dataset.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: nasbench_as_dataset.ipynb <nasbench_as_dataset.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Pruning Bert on Task MNLI\n\n## Workable Pruning Process\n\nHere we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.\n\nThe entire pruning process can be divided into the following steps:\n\n1. Finetune the pre-trained model on the downstream task. From our experience,\n the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.\n At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following\n distillation training.\n2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,\n and directly prune the head (condense the weight) if the head was fully masked.\n If the head was partially masked, we will not prune it and recover its weight.\n3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.\n4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,\n and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.\n5. Retrain the final pruned model with distillation.\n\nDuring the process of pruning transformer, we gained some of the following experiences:\n\n* We using `movement-pruner` in step 2 and `taylor-fo-weight-pruner` in step 4. `movement-pruner` has good performance on attention layers,\n and `taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,\n we also try weight-based pruning algorithms like `l1-norm-pruner`, but it doesn't seem to work well in this scenario.\n* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.\n* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.\n\n## Experiment\n\nThe complete pruning process will take about 8 hours on one A100.\n\n### Preparation\n\nThis section is mainly to get a finetuned model on the downstream task.\nIf you are familiar with how to finetune Bert on GLUE dataset, you can skip this section.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.</p></div>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dev_mode = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some basic setting.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pathlib import Path\nfrom typing import Callable, Dict\n\npretrained_model_name_or_path = 'bert-base-uncased'\ntask_name = 'mnli'\nexperiment_id = 'pruning_bert_mnli'\n\n# heads_num and layers_num should align with pretrained_model_name_or_path\nheads_num = 12\nlayers_num = 12\n\n# used to save the experiment log\nlog_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')\nlog_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name\nmodel_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')\nmodel_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save GLUE data\ndata_dir = Path(f'./data')\ndata_dir.mkdir(parents=True, exist_ok=True)\n\n# set seed\nfrom transformers import set_seed\nset_seed(1024)\n\nimport torch\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create dataloaders.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n\nfrom datasets import load_dataset\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding\n\ntask_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n}\n\ndef prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):\n tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\n sentence1_key, sentence2_key = task_to_keys[task_name]\n data_collator = DataCollatorWithPadding(tokenizer)\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n train_dataset = processed_datasets['train']\n if task_name == 'mnli':\n validation_datasets = {\n 'validation_matched': processed_datasets['validation_matched'],\n 'validation_mismatched': processed_datasets['validation_mismatched']\n }\n else:\n validation_datasets = {\n 'validation': processed_datasets['validation']\n }\n\n train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)\n validation_dataloaders = {\n val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \\\n for val_name, val_dataset in validation_datasets.items()\n }\n\n return train_dataloader, validation_dataloaders\n\n\ntrain_dataloader, validation_dataloaders = prepare_dataloaders()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training function & evaluation function.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n 
predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare pre-trained model and finetuning on downstream task.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import BertForSequenceClassification\n\n\ndef create_pretrained_model():\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n model.bert.config.output_hidden_states = True\n return model\n\n\ndef create_finetuned_model():\n finetuned_model = create_pretrained_model()\n finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'\n\n if finetuned_model_state_path.exists():\n finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))\n finetuned_model.to(device)\n elif dev_mode:\n pass\n else:\n steps_per_epoch = len(train_dataloader)\n training_epochs = 3\n optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\n def lr_lambda(current_step: int):\n return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))\n\n lr_scheduler = LambdaLR(optimizer, lr_lambda)\n training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,\n max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',\n save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)\n return finetuned_model\n\n\nfinetuned_model = create_finetuned_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pruning\nAccording to experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.\nOf course, pruning together can also achieve the similar effect, but more parameter adjustment attempts are required.\nSo in this section, we do pruning in stages.\n\nFirst, we prune the attention layer with MovementPruner.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"steps_per_epoch = len(train_dataloader)\n\n# Set training steps/epochs for pruning.\n\nif not dev_mode:\n total_epochs = 4\n total_steps = total_epochs * steps_per_epoch\n warmup_steps = 1 * steps_per_epoch\n cooldown_steps = 1 * steps_per_epoch\nelse:\n total_epochs = 1\n total_steps = 3\n warmup_steps = 1\n cooldown_steps = 1\n\n# Initialize evaluator used by MovementPruner.\n\nimport nni\nfrom nni.algorithms.compression.v2.pytorch import TorchEvaluator\n\nmovement_training = functools.partial(training, train_dataloader=train_dataloader,\n log_path=log_dir / 'movement_pruning.log',\n evaluation_func=evaluation_func, device=device)\ntraced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n if current_step < warmup_steps:\n return float(current_step) / warmup_steps\n return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))\n\ntraced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)\nevaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)\n\n# Apply block-soft-movement pruning on attention layers.\n# Note that block sparse is introduced by `sparse_granularity='auto'`, and only support `bert`, `bart`, `t5` right now.\n\nfrom nni.compression.pytorch.pruning import MovementPruner\n\nconfig_list = [{\n 'op_types': ['Linear'],\n 'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],\n 'sparsity': 0.1\n}]\n\npruner = MovementPruner(model=finetuned_model,\n config_list=config_list,\n evaluator=evaluator,\n training_epochs=total_epochs,\n training_steps=total_steps,\n warm_up_step=warmup_steps,\n cool_down_beginning_step=total_steps - cooldown_steps,\n regular_scale=10,\n movement_mode='soft',\n sparse_granularity='auto')\n_, attention_masks = pruner.compress()\npruner.show_pruned_weights()\n\ntorch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load a new finetuned model to do speedup, you can think of this as using the finetuned state to initialize the pruned model weights.\nNote that nni speedup don't support replacing attention module, so here we manully replace the attention module.\n\nIf the head is entire masked, physically prune it and create config_list for FFN pruning.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Retrain the attention pruned model with distillation.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"if not dev_mode:\n total_epochs = 5\n total_steps = None\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 1\n distillation = False\n\nteacher_model = create_finetuned_model()\noptimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))\n\nlr_scheduler = LambdaLR(optimizer, lr_lambda)\nat_model_save_path = log_dir / 'attention_pruned_model_state.pth'\ntraining(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,\n max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,\n distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,\n evaluation_func=evaluation_func, device=device)\n\nif not dev_mode:\n attention_pruned_model.load_state_dict(torch.load(at_model_save_path))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.\nFinetuning 3000 steps after each pruning iteration, then finetuning 2 epochs after pruning finished.\n\nNNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"if not dev_mode:\n total_epochs = 7\n total_steps = None\n taylor_pruner_steps = 1000\n steps_per_iteration = 3000\n total_pruning_steps = 36000\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 6\n taylor_pruner_steps = 2\n steps_per_iteration = 2\n total_pruning_steps = 4\n distillation = False\n\nfrom nni.compression.pytorch.pruning import TaylorFOWeightPruner\nfrom nni.compression.pytorch.speedup import ModelSpeedup\n\ndistil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,\n teacher_model=teacher_model, distil_func=distil_func, device=device)\ntraced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\nevaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)\n\ncurrent_step = 0\nbest_result = 0\ninit_lr = 3e-5\n\ndummy_input = torch.rand(8, 128, 768).to(device)\n\nattention_pruned_model.train()\nfor current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if total_steps and current_step >= total_steps:\n break\n # pruning with TaylorFOWeightPruner & reinitialize optimizer\n if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:\n check_point = attention_pruned_model.state_dict()\n pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)\n _, ffn_masks = pruner.compress()\n renamed_ffn_masks = {}\n # rename the masks keys, because we only speedup the bert.encoder\n for model_name, targets_mask in ffn_masks.items():\n renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask\n pruner._unwrap_model()\n attention_pruned_model.load_state_dict(check_point)\n ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()\n optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)\n\n batch.to(device)\n # manually schedule lr\n for params_group in optimizer.param_groups:\n params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr\n\n outputs = attention_pruned_model(**batch)\n loss = outputs.loss\n\n # distillation\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n current_step += 1\n\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(attention_pruned_model)\n with (log_dir / 'ffn_pruning.log').open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())),\n current_epoch, current_step, result)\n f.write(msg)\n if current_step >= total_pruning_steps and best_result < result['default']:\n torch.save(attention_pruned_model, log_dir / 'best_model.pth')\n best_result = result['default']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Result\nThe speedup is test on the entire validation dataset with batch size 128 on A100.\nWe test under two pytorch version and found the latency varying widely.\n\nSetting 1: pytorch 1.12.1\n\nSetting 2: pytorch 1.10.0\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Attention Pruning Method\n - FFN Pruning Method\n - Total Sparsity\n - Accuracy\n - Acc. Drop\n - Speedup (S1)\n - Speedup (S2)\n * -\n -\n - 85.1M (-0.0%)\n - 84.85 / 85.28\n - +0.0 / +0.0\n - 25.60s (x1.00)\n - 8.10s (x1.00)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=1)\n - `taylor-fo-weight-pruner`\n - 54.1M (-36.43%)\n - 85.38 / 85.41\n - +0.53 / +0.13\n - 17.93s (x1.43)\n - 7.22s (x1.12)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=5)\n - `taylor-fo-weight-pruner`\n - 37.1M (-56.40%)\n - 84.73 / 85.12\n - -0.12 / -0.16\n - 12.83s (x2.00)\n - 5.61s (x1.44)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=10)\n - `taylor-fo-weight-pruner`\n - 24.1M (-71.68%)\n - 84.14 / 84.78\n - -0.71 / -0.50\n - 8.93s (x2.87)\n - 4.55s (x1.78)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=20)\n - `taylor-fo-weight-pruner`\n - 14.3M (-83.20%)\n - 83.26 / 82.96\n - -1.59 / -2.32\n - 5.98s (x4.28)\n - 3.56s (x2.28)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=30)\n - `taylor-fo-weight-pruner`\n - 9.9M (-88.37%)\n - 82.22 / 82.19\n - -2.63 / -3.09\n - 4.36s (x5.88)\n - 3.12s (x2.60)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=40)\n - `taylor-fo-weight-pruner`\n - 8.8M (-89.66%)\n - 81.64 / 82.39\n - -3.21 / -2.89\n - 3.88s (x6.60)\n - 2.81s (x2.88)\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
"""
Pruning Bert on Task MNLI
=========================
Workable Pruning Process
------------------------
Here we show an effective transformer pruning process that the NNI team has tried, and users can use NNI to discover even better processes.

The entire pruning process can be divided into the following steps:

1. Finetune the pre-trained model on the downstream task. From our experience,
   the final performance of pruning the finetuned model is better than pruning the pre-trained model directly.
   At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following
   distillation training.
2. Prune the attention layers first. Here we apply block sparsity on the attention layer weights,
   and directly prune a head (condense the weight) if it is fully masked.
   If a head is only partially masked, we do not prune it and recover its weight.
3. Retrain the head-pruned model with distillation to recover the model precision before pruning the FFN layers.
4. Prune the FFN layers. Here we prune the output channels of the 1st FFN layer,
   and the input channels of the 2nd FFN layer are pruned accordingly because the 1st-layer output channels are gone.
5. Retrain the final pruned model with distillation.

During the process of pruning the transformer, we gained the following experience:

* We use :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4. :ref:`movement-pruner` performs well on attention layers,
  and :ref:`taylor-fo-weight-pruner` performs well on FFN layers. Both are gradient-based pruning algorithms;
  we also tried weight-based pruning algorithms like :ref:`l1-norm-pruner`, but they do not seem to work well in this scenario.
* Distillation is a good way to recover model precision. In our experiments, it usually brings a 1~2% accuracy improvement when pruning BERT on the MNLI task.
* It is necessary to increase the sparsity gradually rather than reaching a very high sparsity all at once.
Experiment
----------

The complete pruning process takes about 8 hours on one A100.

Preparation
^^^^^^^^^^^

The main purpose of this section is to obtain a model finetuned on the downstream task.
If you are familiar with how to finetune BERT on GLUE datasets, you can skip this section.

.. note::

    Please set ``dev_mode`` to ``False`` to run this tutorial. ``dev_mode`` is ``True`` by default only for generating the documentation.
"""
dev_mode = True
# %%
# Some basic settings.
from pathlib import Path
from typing import Callable, Dict
pretrained_model_name_or_path = 'bert-base-uncased'
task_name = 'mnli'
experiment_id = 'pruning_bert_mnli'
# heads_num and layers_num should align with pretrained_model_name_or_path
heads_num = 12
layers_num = 12
# used to save the experiment log
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
log_dir.mkdir(parents=True, exist_ok=True)
# used to save the finetuned model and share it between different experiments with the same pretrained_model_name_or_path and task_name
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True)
# used to save GLUE data
data_dir = Path(f'./data')
data_dir.mkdir(parents=True, exist_ok=True)
# set seed
from transformers import set_seed
set_seed(1024)
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# %%
# Create dataloaders.
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding
task_to_keys = {
'cola': ('sentence', None),
'mnli': ('premise', 'hypothesis'),
'mrpc': ('sentence1', 'sentence2'),
'qnli': ('question', 'sentence'),
'qqp': ('question1', 'question2'),
'rte': ('sentence1', 'sentence2'),
'sst2': ('sentence', None),
'stsb': ('sentence1', 'sentence2'),
'wnli': ('sentence1', 'sentence2'),
}
def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
sentence1_key, sentence2_key = task_to_keys[task_name]
data_collator = DataCollatorWithPadding(tokenizer)
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if 'label' in examples:
# In all cases, rename the column to labels because the model will expect that.
result['labels'] = examples['label']
return result
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
for key in list(raw_datasets.keys()):
if 'test' in key:
raw_datasets.pop(key)
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names)
train_dataset = processed_datasets['train']
if task_name == 'mnli':
validation_datasets = {
'validation_matched': processed_datasets['validation_matched'],
'validation_mismatched': processed_datasets['validation_mismatched']
}
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
validation_dataloaders = {
val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
for val_name, val_dataset in validation_datasets.items()
}
return train_dataloader, validation_dataloaders
train_dataloader, validation_dataloaders = prepare_dataloaders()
# %%
# Training function & evaluation function.
import functools
import time
import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
max_steps: int = None,
max_epochs: int = None,
train_dataloader: DataLoader = None,
distillation: bool = False,
teacher_model: torch.nn.Module = None,
distil_func: Callable = None,
log_path: str = Path(log_dir) / 'training.log',
save_best_model: bool = False,
save_path: str = None,
evaluation_func: Callable = None,
eval_per_steps: int = 1000,
device=None):
assert train_dataloader is not None
model.train()
if teacher_model is not None:
teacher_model.eval()
current_step = 0
best_result = 0
total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)
print(f'Training {total_epochs} epochs, {total_steps} steps...')
for current_epoch in range(total_epochs):
for batch in train_dataloader:
if current_step >= total_steps:
return
batch.to(device)
outputs = model(**batch)
loss = outputs.loss
if distillation:
assert teacher_model is not None
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss
loss = criterion(loss, None)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# per step schedule
if lr_scheduler:
lr_scheduler.step()
current_step += 1
if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(model) if evaluation_func else None
with (log_path).open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
f.write(msg)
# if it's the best model, save it.
if save_best_model and (result is None or best_result < result['default']):
assert save_path is not None
torch.save(model.state_dict(), save_path)
best_result = None if result is None else result['default']
def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
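    # Distillation loss: MSE between matched student/teacher encoder hidden states,
    # plus KL divergence between temperature-softened logits (T=2), scaled by T**2 as usual for soft targets.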
encoder_hidden_state_loss = []
for i, idx in enumerate(encoder_layer_idxs[:-1]):
encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
distil_loss = 0
for loss in encoder_hidden_state_loss:
distil_loss += loss
distil_loss += logits_loss
return distil_loss
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
assert validation_dataloaders is not None
training = model.training
model.eval()
is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name)
result = {}
default_result = 0
for val_name, validation_dataloader in validation_dataloaders.items():
for batch in validation_dataloader:
batch.to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(
predictions=predictions,
references=batch['labels'],
)
result[val_name] = metric.compute()
default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
result['default'] = default_result / len(result)
model.train(training)
return result
evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)
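# The HuggingFace model already returns its loss (and the distillation branch adds to it inside `training`),
# so the criterion passed around below only needs to forward the pre-computed loss unchanged.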
def fake_criterion(loss, _):
return loss
# %%
# Prepare the pre-trained model and finetune it on the downstream task.
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification
def create_pretrained_model():
is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
model.bert.config.output_hidden_states = True
return model
def create_finetuned_model():
finetuned_model = create_pretrained_model()
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
if finetuned_model_state_path.exists():
finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
finetuned_model.to(device)
elif dev_mode:
pass
else:
steps_per_epoch = len(train_dataloader)
training_epochs = 3
optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda)
training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
return finetuned_model
finetuned_model = create_finetuned_model()
# %%
# Pruning
# ^^^^^^^
# In our experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.
# Of course, pruning them together can achieve a similar effect, but it requires more attempts at parameter tuning.
# So in this section, we do the pruning in stages.
#
# First, we prune the attention layer with MovementPruner.
steps_per_epoch = len(train_dataloader)
# Set training steps/epochs for pruning.
if not dev_mode:
total_epochs = 4
total_steps = total_epochs * steps_per_epoch
warmup_steps = 1 * steps_per_epoch
cooldown_steps = 1 * steps_per_epoch
else:
total_epochs = 1
total_steps = 3
warmup_steps = 1
cooldown_steps = 1
# Initialize evaluator used by MovementPruner.
import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
movement_training = functools.partial(training, train_dataloader=train_dataloader,
log_path=log_dir / 'movement_pruning.log',
evaluation_func=evaluation_func, device=device)
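# `nni.trace` records the constructor arguments, so the evaluator / pruner can re-instantiate the optimizer
# and scheduler for the wrapped model when it needs to.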
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
if current_step < warmup_steps:
return float(current_step) / warmup_steps
return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
# Apply block soft movement pruning on the attention layers.
# Note that block sparsity is enabled by `sparse_granularity='auto'`, which only supports `bert`, `bart` and `t5` right now.
from nni.compression.pytorch.pruning import MovementPruner
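# `op_partial_names` matches modules whose names contain the given substrings,
# so this config selects every Linear layer inside each encoder layer's attention block.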
config_list = [{
'op_types': ['Linear'],
'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
'sparsity': 0.1
}]
pruner = MovementPruner(model=finetuned_model,
config_list=config_list,
evaluator=evaluator,
training_epochs=total_epochs,
training_steps=total_steps,
warm_up_step=warmup_steps,
cool_down_beginning_step=total_steps - cooldown_steps,
regular_scale=10,
movement_mode='soft',
sparse_granularity='auto')
_, attention_masks = pruner.compress()
pruner.show_pruned_weights()
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
# %%
# Load a new finetuned model for speedup; you can think of this as using the finetuned state to initialize the pruned model weights.
# Note that NNI speedup does not support replacing the attention module, so here we replace it manually.
#
# If a head is entirely masked, physically prune it and create the config_list for FFN pruning.
attention_pruned_model = create_finetuned_model().to(device)
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
ffn_config_list = []
layer_remained_idxs = []
module_list = []
for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
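    # A head counts as fully masked only if every element of its block in the value-weight mask is zero.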
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
if len(head_idxs) != heads_num:
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
module_list.append(attention_pruned_model.bert.encoder.layer[i])
        # The final FFN weight remaining ratio is half of the attention weight remaining ratio.
        # This is just an empirical configuration; you can use any other method to determine this sparsity.
sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
        # Here we use a simple sparsity schedule: we prune the FFN in 12 iterations, each iteration pruning `sparsity_per_iter` of the remaining channels.
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
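        # e.g., if 6 of the 12 heads are pruned, the attention remaining ratio is 0.5, so the target FFN
        # remaining ratio is 0.25 (sparsity = 0.75); pruning `sparsity_per_iter` of the remaining channels
        # in each of the 12 iterations compounds to (1 - sparsity_per_iter) ** 12 = 1 - sparsity.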
ffn_config_list.append({
'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
'sparsity': sparsity_per_iter
})
layer_remained_idxs.append(i)
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
# %%
# Retrain the attention pruned model with distillation.
if not dev_mode:
total_epochs = 5
total_steps = None
distillation = True
else:
total_epochs = 1
total_steps = 1
distillation = False
teacher_model = create_finetuned_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
evaluation_func=evaluation_func, device=device)
if not dev_mode:
attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
# %%
# Iteratively prune the FFN with TaylorFOWeightPruner in 12 iterations,
# finetuning 3000 steps after each pruning iteration, then finetuning 2 epochs after pruning is finished.
#
# NNI will support a per-step pruning schedule in the future; a single pruner will then be able to replace the following code.
if not dev_mode:
total_epochs = 7
total_steps = None
taylor_pruner_steps = 1000
steps_per_iteration = 3000
total_pruning_steps = 36000
distillation = True
else:
total_epochs = 1
total_steps = 6
taylor_pruner_steps = 2
steps_per_iteration = 2
total_pruning_steps = 4
distillation = False
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup
distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
current_step = 0
best_result = 0
init_lr = 3e-5
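# dummy input for ModelSpeedup: (batch, seq_len, hidden_size) matches the input expected by bert.encoder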
dummy_input = torch.rand(8, 128, 768).to(device)
attention_pruned_model.train()
for current_epoch in range(total_epochs):
for batch in train_dataloader:
if total_steps and current_step >= total_steps:
break
# pruning with TaylorFOWeightPruner & reinitialize optimizer
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
check_point = attention_pruned_model.state_dict()
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
_, ffn_masks = pruner.compress()
renamed_ffn_masks = {}
            # rename the mask keys, because we only speed up bert.encoder
for model_name, targets_mask in ffn_masks.items():
renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
pruner._unwrap_model()
attention_pruned_model.load_state_dict(check_point)
ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
batch.to(device)
# manually schedule lr
for params_group in optimizer.param_groups:
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
outputs = attention_pruned_model(**batch)
loss = outputs.loss
# distillation
if distillation:
assert teacher_model is not None
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
current_step += 1
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(attention_pruned_model)
with (log_dir / 'ffn_pruning.log').open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
current_epoch, current_step, result)
f.write(msg)
if current_step >= total_pruning_steps and best_result < result['default']:
torch.save(attention_pruned_model, log_dir / 'best_model.pth')
best_result = result['default']
# %%
# Result
# ------
# The speedup is tested on the entire validation dataset with batch size 128 on an A100.
# We test under two PyTorch versions and find that the latency varies widely.
#
# Setting 1: pytorch 1.12.1
#
# Setting 2: pytorch 1.10.0
#
# .. list-table:: Prune Bert-base-uncased on MNLI
# :header-rows: 1
# :widths: auto
#
# * - Attention Pruning Method
# - FFN Pruning Method
# - Total Sparsity
# - Accuracy
# - Acc. Drop
# - Speedup (S1)
# - Speedup (S2)
# * -
# -
# - 85.1M (-0.0%)
# - 84.85 / 85.28
# - +0.0 / +0.0
# - 25.60s (x1.00)
# - 8.10s (x1.00)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=1)
# - :ref:`taylor-fo-weight-pruner`
# - 54.1M (-36.43%)
# - 85.38 / 85.41
# - +0.53 / +0.13
# - 17.93s (x1.43)
# - 7.22s (x1.12)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=5)
# - :ref:`taylor-fo-weight-pruner`
# - 37.1M (-56.40%)
# - 84.73 / 85.12
# - -0.12 / -0.16
# - 12.83s (x2.00)
# - 5.61s (x1.44)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=10)
# - :ref:`taylor-fo-weight-pruner`
# - 24.1M (-71.68%)
# - 84.14 / 84.78
# - -0.71 / -0.50
# - 8.93s (x2.87)
# - 4.55s (x1.78)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=20)
# - :ref:`taylor-fo-weight-pruner`
# - 14.3M (-83.20%)
# - 83.26 / 82.96
# - -1.59 / -2.32
# - 5.98s (x4.28)
# - 3.56s (x2.28)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=30)
# - :ref:`taylor-fo-weight-pruner`
# - 9.9M (-88.37%)
# - 82.22 / 82.19
# - -2.63 / -3.09
# - 4.36s (x5.88)
# - 3.12s (x2.60)
# * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=40)
# - :ref:`taylor-fo-weight-pruner`
# - 8.8M (-89.66%)
# - 81.64 / 82.39
# - -3.21 / -2.89
# - 3.88s (x6.60)
# - 2.81s (x2.88)
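# %%
# A minimal sketch (not part of the original experiment) of how such a latency measurement could look,
# reusing the helpers defined above; the variable names here are illustrative only.
if not dev_mode and (log_dir / 'best_model.pth').exists():
    # rebuild the validation dataloaders with batch size 128 and time one full pass over the matched split
    _, val_dataloaders_bs128 = prepare_dataloaders(eval_batch_size=128)
    pruned_model = torch.load(log_dir / 'best_model.pth').to(device)
    pruned_model.eval()
    with torch.no_grad():
        start = time.time()
        for batch in val_dataloaders_bs128['validation_matched']:
            batch.to(device)
            pruned_model(**batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f'validation_matched latency: {time.time() - start:.2f}s')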
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/pruning_bert_glue.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_tutorials_pruning_bert_glue.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorials_pruning_bert_glue.py:
Pruning Bert on Task MNLI
=========================
Workable Pruning Process
------------------------
Here we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.
The entire pruning process can be divided into the following steps:
1. Finetune the pre-trained model on the downstream task. From our experience,
the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.
At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following
distillation training.
2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,
and directly prune the head (condense the weight) if the head was fully masked.
If the head was partially masked, we will not prune it and recover its weight.
3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.
4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,
and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.
5. Retrain the final pruned model with distillation.
During the process of pruning transformer, we gained some of the following experiences:
* We using :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4. :ref:`movement-pruner` has good performance on attention layers,
and :ref:`taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,
we also try weight-based pruning algorithms like :ref:`l1-norm-pruner`, but it doesn't seem to work well in this scenario.
* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.
* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.
Experiment
----------
The complete pruning process will take about 8 hours on one A100.
Preparation
^^^^^^^^^^^
This section is mainly to get a finetuned model on the downstream task.
If you are familiar with how to finetune Bert on GLUE dataset, you can skip this section.
.. note::
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
.. GENERATED FROM PYTHON SOURCE LINES 48-51
.. code-block:: default
dev_mode = True
.. GENERATED FROM PYTHON SOURCE LINES 52-53
Some basic setting.
.. GENERATED FROM PYTHON SOURCE LINES 53-84
.. code-block:: default
from pathlib import Path
from typing import Callable, Dict
pretrained_model_name_or_path = 'bert-base-uncased'
task_name = 'mnli'
experiment_id = 'pruning_bert_mnli'
# heads_num and layers_num should align with pretrained_model_name_or_path
heads_num = 12
layers_num = 12
# used to save the experiment log
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
log_dir.mkdir(parents=True, exist_ok=True)
# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True)
# used to save GLUE data
data_dir = Path(f'./data')
data_dir.mkdir(parents=True, exist_ok=True)
# set seed
from transformers import set_seed
set_seed(1024)
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
.. GENERATED FROM PYTHON SOURCE LINES 85-86
Create dataloaders.
.. GENERATED FROM PYTHON SOURCE LINES 86-152
.. code-block:: default
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding
task_to_keys = {
'cola': ('sentence', None),
'mnli': ('premise', 'hypothesis'),
'mrpc': ('sentence1', 'sentence2'),
'qnli': ('question', 'sentence'),
'qqp': ('question1', 'question2'),
'rte': ('sentence1', 'sentence2'),
'sst2': ('sentence', None),
'stsb': ('sentence1', 'sentence2'),
'wnli': ('sentence1', 'sentence2'),
}
def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
sentence1_key, sentence2_key = task_to_keys[task_name]
data_collator = DataCollatorWithPadding(tokenizer)
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if 'label' in examples:
# In all cases, rename the column to labels because the model will expect that.
result['labels'] = examples['label']
return result
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
for key in list(raw_datasets.keys()):
if 'test' in key:
raw_datasets.pop(key)
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names)
train_dataset = processed_datasets['train']
if task_name == 'mnli':
validation_datasets = {
'validation_matched': processed_datasets['validation_matched'],
'validation_mismatched': processed_datasets['validation_mismatched']
}
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
validation_dataloaders = {
val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
for val_name, val_dataset in validation_datasets.items()
}
return train_dataloader, validation_dataloaders
train_dataloader, validation_dataloaders = prepare_dataloaders()
.. GENERATED FROM PYTHON SOURCE LINES 153-154
Training function & evaluation function.
.. GENERATED FROM PYTHON SOURCE LINES 154-277
.. code-block:: default
import functools
import time
import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
max_steps: int = None,
max_epochs: int = None,
train_dataloader: DataLoader = None,
distillation: bool = False,
teacher_model: torch.nn.Module = None,
distil_func: Callable = None,
log_path: str = Path(log_dir) / 'training.log',
save_best_model: bool = False,
save_path: str = None,
evaluation_func: Callable = None,
eval_per_steps: int = 1000,
device=None):
assert train_dataloader is not None
model.train()
if teacher_model is not None:
teacher_model.eval()
current_step = 0
best_result = 0
total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)
print(f'Training {total_epochs} epochs, {total_steps} steps...')
for current_epoch in range(total_epochs):
for batch in train_dataloader:
if current_step >= total_steps:
return
batch.to(device)
outputs = model(**batch)
loss = outputs.loss
if distillation:
assert teacher_model is not None
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss
loss = criterion(loss, None)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# per step schedule
if lr_scheduler:
lr_scheduler.step()
current_step += 1
if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(model) if evaluation_func else None
with (log_path).open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
f.write(msg)
# if it's the best model, save it.
if save_best_model and (result is None or best_result < result['default']):
assert save_path is not None
torch.save(model.state_dict(), save_path)
best_result = None if result is None else result['default']
def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
encoder_hidden_state_loss = []
for i, idx in enumerate(encoder_layer_idxs[:-1]):
encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
distil_loss = 0
for loss in encoder_hidden_state_loss:
distil_loss += loss
distil_loss += logits_loss
return distil_loss
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
assert validation_dataloaders is not None
training = model.training
model.eval()
is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name)
result = {}
default_result = 0
for val_name, validation_dataloader in validation_dataloaders.items():
for batch in validation_dataloader:
batch.to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(
predictions=predictions,
references=batch['labels'],
)
result[val_name] = metric.compute()
default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
result['default'] = default_result / len(result)
model.train(training)
return result
evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)
def fake_criterion(loss, _):
return loss
.. GENERATED FROM PYTHON SOURCE LINES 278-279
Prepare pre-trained model and finetuning on downstream task.
.. GENERATED FROM PYTHON SOURCE LINES 279-320
.. code-block:: default
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification
def create_pretrained_model():
is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
model.bert.config.output_hidden_states = True
return model
def create_finetuned_model():
finetuned_model = create_pretrained_model()
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
if finetuned_model_state_path.exists():
finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
finetuned_model.to(device)
elif dev_mode:
pass
else:
steps_per_epoch = len(train_dataloader)
training_epochs = 3
optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda)
training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
return finetuned_model
finetuned_model = create_finetuned_model()
.. GENERATED FROM PYTHON SOURCE LINES 321-328
Pruning
^^^^^^^
According to experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.
Of course, pruning together can also achieve the similar effect, but more parameter adjustment attempts are required.
So in this section, we do pruning in stages.
First, we prune the attention layer with MovementPruner.
.. GENERATED FROM PYTHON SOURCE LINES 328-388
.. code-block:: default
steps_per_epoch = len(train_dataloader)
# Set training steps/epochs for pruning.
if not dev_mode:
total_epochs = 4
total_steps = total_epochs * steps_per_epoch
warmup_steps = 1 * steps_per_epoch
cooldown_steps = 1 * steps_per_epoch
else:
total_epochs = 1
total_steps = 3
warmup_steps = 1
cooldown_steps = 1
# Initialize evaluator used by MovementPruner.
import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
movement_training = functools.partial(training, train_dataloader=train_dataloader,
log_path=log_dir / 'movement_pruning.log',
evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
if current_step < warmup_steps:
return float(current_step) / warmup_steps
return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
# Apply block-soft-movement pruning on attention layers.
# Note that block sparse is introduced by `sparse_granularity='auto'`, and only support `bert`, `bart`, `t5` right now.
from nni.compression.pytorch.pruning import MovementPruner
config_list = [{
'op_types': ['Linear'],
'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
'sparsity': 0.1
}]
pruner = MovementPruner(model=finetuned_model,
config_list=config_list,
evaluator=evaluator,
training_epochs=total_epochs,
training_steps=total_steps,
warm_up_step=warmup_steps,
cool_down_beginning_step=total_steps - cooldown_steps,
regular_scale=10,
movement_mode='soft',
sparse_granularity='auto')
_, attention_masks = pruner.compress()
pruner.show_pruned_weights()
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
.. GENERATED FROM PYTHON SOURCE LINES 389-393
Load a new finetuned model to do speedup, you can think of this as using the finetuned state to initialize the pruned model weights.
Note that nni speedup don't support replacing attention module, so here we manully replace the attention module.
If the head is entire masked, physically prune it and create config_list for FFN pruning.
.. GENERATED FROM PYTHON SOURCE LINES 393-423
.. code-block:: default
attention_pruned_model = create_finetuned_model().to(device)
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
ffn_config_list = []
layer_remained_idxs = []
module_list = []
for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
if len(head_idxs) != heads_num:
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
module_list.append(attention_pruned_model.bert.encoder.layer[i])
# The final ffn weight remaining ratio is the half of the attention weight remaining ratio.
# This is just an empirical configuration, you can use any other method to determine this sparsity.
sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
# here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
ffn_config_list.append({
'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
'sparsity': sparsity_per_iter
})
layer_remained_idxs.append(i)
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
.. GENERATED FROM PYTHON SOURCE LINES 424-425
Retrain the attention pruned model with distillation.
.. GENERATED FROM PYTHON SOURCE LINES 425-451
.. code-block:: default
if not dev_mode:
total_epochs = 5
total_steps = None
distillation = True
else:
total_epochs = 1
total_steps = 1
distillation = False
teacher_model = create_finetuned_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int):
return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
evaluation_func=evaluation_func, device=device)
if not dev_mode:
attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
.. GENERATED FROM PYTHON SOURCE LINES 452-456
Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
Finetuning 3000 steps after each pruning iteration, then finetuning 2 epochs after pruning finished.
NNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.
.. GENERATED FROM PYTHON SOURCE LINES 456-537
.. code-block:: default
if not dev_mode:
total_epochs = 7
total_steps = None
taylor_pruner_steps = 1000
steps_per_iteration = 3000
total_pruning_steps = 36000
distillation = True
else:
total_epochs = 1
total_steps = 6
taylor_pruner_steps = 2
steps_per_iteration = 2
total_pruning_steps = 4
distillation = False
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup
distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
current_step = 0
best_result = 0
init_lr = 3e-5
dummy_input = torch.rand(8, 128, 768).to(device)
attention_pruned_model.train()
for current_epoch in range(total_epochs):
for batch in train_dataloader:
if total_steps and current_step >= total_steps:
break
# pruning with TaylorFOWeightPruner & reinitialize optimizer
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
check_point = attention_pruned_model.state_dict()
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
_, ffn_masks = pruner.compress()
renamed_ffn_masks = {}
# rename the masks keys, because we only speedup the bert.encoder
for model_name, targets_mask in ffn_masks.items():
renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
pruner._unwrap_model()
attention_pruned_model.load_state_dict(check_point)
ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
batch.to(device)
# manually schedule lr
for params_group in optimizer.param_groups:
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
outputs = attention_pruned_model(**batch)
loss = outputs.loss
# distillation
if distillation:
assert teacher_model is not None
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
current_step += 1
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(attention_pruned_model)
with (log_dir / 'ffn_pruning.log').open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
current_epoch, current_step, result)
f.write(msg)
if current_step >= total_pruning_steps and best_result < result['default']:
torch.save(attention_pruned_model, log_dir / 'best_model.pth')
best_result = result['default']
.. GENERATED FROM PYTHON SOURCE LINES 538-607
Result
------
The speedup is test on the entire validation dataset with batch size 128 on A100.
We test under two pytorch version and found the latency varying widely.
Setting 1: pytorch 1.12.1
Setting 2: pytorch 1.10.0
.. list-table:: Prune Bert-base-uncased on MNLI
:header-rows: 1
:widths: auto
* - Attention Pruning Method
- FFN Pruning Method
- Total Sparsity
- Accuracy
- Acc. Drop
- Speedup (S1)
- Speedup (S2)
* -
-
- 85.1M (-0.0%)
- 84.85 / 85.28
- +0.0 / +0.0
- 25.60s (x1.00)
- 8.10s (x1.00)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=1)
- :ref:`taylor-fo-weight-pruner`
- 54.1M (-36.43%)
- 85.38 / 85.41
- +0.53 / +0.13
- 17.93s (x1.43)
- 7.22s (x1.12)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=5)
- :ref:`taylor-fo-weight-pruner`
- 37.1M (-56.40%)
- 84.73 / 85.12
- -0.12 / -0.16
- 12.83s (x2.00)
- 5.61s (x1.44)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=10)
- :ref:`taylor-fo-weight-pruner`
- 24.1M (-71.68%)
- 84.14 / 84.78
- -0.71 / -0.50
- 8.93s (x2.87)
- 4.55s (x1.78)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=20)
- :ref:`taylor-fo-weight-pruner`
- 14.3M (-83.20%)
- 83.26 / 82.96
- -1.59 / -2.32
- 5.98s (x4.28)
- 3.56s (x2.28)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=30)
- :ref:`taylor-fo-weight-pruner`
- 9.9M (-88.37%)
- 82.22 / 82.19
- -2.63 / -3.09
- 4.36s (x5.88)
- 3.12s (x2.60)
* - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=40)
- :ref:`taylor-fo-weight-pruner`
- 8.8M (-89.66%)
- 81.64 / 82.39
- -3.21 / -2.89
- 3.88s (x6.60)
- 2.81s (x2.88)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 20.822 seconds)
.. _sphx_glr_download_tutorials_pruning_bert_glue.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: pruning_bert_glue.py <pruning_bert_glue.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: pruning_bert_glue.ipynb <pruning_bert_glue.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Customize Basic Pruner\n\nUsers can easily customize a basic pruner in NNI. A large number of basic modules have been provided and can be reused.\nFollow the NNI pruning interface, users only need to focus on their creative parts without worrying about other regular modules.\n\nIn this tutorial, we show how to customize a basic pruner.\n\n## Concepts\n\nNNI abstracts the basic pruning process into three steps, collecting data, calculating metrics, allocating sparsity.\nMost pruning algorithms rely on a metric to decide where should be pruned. Using L1 norm pruner as an example,\nthe first step is collecting model weights, the second step is calculating L1 norm for weight per output channel,\nthe third step is ranking L1 norm metric and masking the output channels that have small L1 norm.\n\nIn NNI basic pruner, these three step is implement as ``DataCollector``, ``MetricsCalculator`` and ``SparsityAllocator``.\n\n- ``DataCollector``: This module take pruner as initialize parameter.\n It will get the relevant information of the model from the pruner,\n and sometimes it will also hook the model to get input, output or gradient of a layer or a tensor.\n It can also patch optimizer if some special steps need to be executed before or after ``optimizer.step()``.\n\n- ``MetricsCalculator``: This module will take the data collected from the ``DataCollector``,\n then calculate the metrics. The metric shape is usually reduced from the data shape.\n The ``dim`` taken by ``MetricsCalculator`` means which dimension will be kept after calculate metrics.\n i.e., the collected data shape is (10, 20, 30), and the ``dim`` is 1, then the dimension-1 will be kept,\n the output metrics shape should be (20,).\n\n- ``SparsityAllocator``: This module take the metrics and generate the masks.\n Different ``SparsityAllocator`` has different masks generation strategies.\n A common and simple strategy is sorting the metrics' values and calculating a threshold according to the configured sparsity,\n mask the positions which metric value smaller than the threshold.\n The ``dim`` taken by ``SparsityAllocator`` means the metrics are for which dimension, the mask will be expanded to weight shape.\n i.e., the metric shape is (20,), the corresponding layer weight shape is (20, 40), and the ``dim`` is 0.\n ``SparsityAllocator`` will first generate a mask with shape (20,), then expand this mask to shape (20, 40).\n\n## Simple Example: Customize a Block-L1NormPruner\n\nNNI already have L1NormPruner, but for the reason of reproducing the paper and reducing user configuration items,\nit only support pruning layer output channels. In this example, we will customize a pruner that supports block granularity for Linear.\n\nNote that you don't need to implement all these three kinds of tools for each time,\nNNI supports many predefined tools, and you can directly use these to customize your own pruner.\nThis is a tutorial so we show how to define all these three kinds of pruning tools.\n\nCustomize the pruning tools used by the pruner at first.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\nfrom nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner\nfrom nni.algorithms.compression.v2.pytorch.pruning.tools import (\n DataCollector,\n MetricsCalculator,\n SparsityAllocator\n)\n\n\n# This data collector collects weight in wrapped module as data.\n# The wrapped module is the module configured in pruner's config_list.\n# This implementation is similar as nni.algorithms.compression.v2.pytorch.pruning.tools.WeightDataCollector\nclass WeightDataCollector(DataCollector):\n def collect(self):\n data = {}\n # get_modules_wrapper will get all the wrapper in the compressor (pruner),\n # it returns a dict with format {wrapper_name: wrapper},\n # use wrapper.module to get the wrapped module.\n for _, wrapper in self.compressor.get_modules_wrapper().items():\n data[wrapper.name] = wrapper.module.weight.data\n # return {wrapper_name: weight_data}\n return data\n\n\nclass BlockNormMetricsCalculator(MetricsCalculator):\n def __init__(self, block_sparse_size):\n # Because we will keep all dimension with block granularity, so fix ``dim=None``,\n # means all dimensions will be kept.\n super().__init__(dim=None, block_sparse_size=block_sparse_size)\n\n def calculate_metrics(self, data):\n data_length = len(self.block_sparse_size)\n reduce_unfold_dims = list(range(data_length, 2 * data_length))\n\n metrics = {}\n for name, t in data.items():\n # Unfold t as block size, and calculate L1 Norm for each block.\n for dim, size in enumerate(self.block_sparse_size):\n t = t.unfold(dim, size, size)\n metrics[name] = t.norm(dim=reduce_unfold_dims, p=1)\n # return {wrapper_name: block_metric}\n return metrics\n\n\n# This implementation is similar as nni.algorithms.compression.v2.pytorch.pruning.tools.NormalSparsityAllocator\nclass BlockSparsityAllocator(SparsityAllocator):\n def __init__(self, pruner, block_sparse_size):\n super().__init__(pruner, dim=None, block_sparse_size=block_sparse_size, continuous_mask=True)\n\n def generate_sparsity(self, metrics):\n masks = {}\n for name, wrapper in self.pruner.get_modules_wrapper().items():\n # wrapper.config['total_sparsity'] can get the configured sparsity ratio for this wrapped module\n sparsity_rate = wrapper.config['total_sparsity']\n # get metric for this wrapped module\n metric = metrics[name]\n # mask the metric with old mask, if the masked position need never recover,\n # just keep this is ok if you are new in NNI pruning\n if self.continuous_mask:\n metric *= self._compress_mask(wrapper.weight_mask)\n # convert sparsity ratio to prune number\n prune_num = int(sparsity_rate * metric.numel())\n # calculate the metric threshold\n threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()\n # generate mask, keep the metric positions that metric values greater than the threshold\n mask = torch.gt(metric, threshold).type_as(metric)\n # expand the mask to weight size, if the block is masked, this block will be filled with zeros,\n # otherwise filled with ones\n masks[name] = self._expand_mask(name, mask)\n # merge the new mask with old mask, if the masked position need never recover,\n # just keep this is ok if you are new in NNI pruning\n if self.continuous_mask:\n masks[name]['weight'] *= wrapper.weight_mask\n return masks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Customize the pruner.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class BlockL1NormPruner(BasicPruner):\n def __init__(self, model, config_list, block_sparse_size):\n self.block_sparse_size = block_sparse_size\n super().__init__(model, config_list)\n\n # Implement reset_tools is enough for this pruner.\n def reset_tools(self):\n if self.data_collector is None:\n self.data_collector = WeightDataCollector(self)\n else:\n self.data_collector.reset()\n if self.metrics_calculator is None:\n self.metrics_calculator = BlockNormMetricsCalculator(self.block_sparse_size)\n if self.sparsity_allocator is None:\n self.sparsity_allocator = BlockSparsityAllocator(self, self.block_sparse_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try this pruner.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Define a simple model.\nclass TestModel(torch.nn.Module):\n def __init__(self) -> None:\n super().__init__()\n self.fc1 = torch.nn.Linear(4, 8)\n self.fc2 = torch.nn.Linear(8, 4)\n\n def forward(self, x):\n return self.fc2(self.fc1(x))\n\nmodel = TestModel()\nconfig_list = [{'op_types': ['Linear'], 'total_sparsity': 0.5}]\n# use 2x2 block\n_, masks = BlockL1NormPruner(model, config_list, [2, 2]).compress()\n\n# show the generated masks\nprint('fc1 masks:\\n', masks['fc1']['weight'])\nprint('fc2 masks:\\n', masks['fc2']['weight'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This time we successfully define a new pruner with pruning block granularity!\nNote that we don't put validation logic in this example, like ``_validate_config_before_canonical``,\nbut for a robust implementation, we suggest you involve the validation logic.\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
"""
Customize Basic Pruner
======================
Users can easily customize a basic pruner in NNI. A large number of basic modules are provided and can be reused.
By following the NNI pruning interface, users only need to focus on the creative parts without worrying about the other regular modules.

In this tutorial, we show how to customize a basic pruner.
Concepts
--------
NNI abstracts the basic pruning process into three steps: collecting data, calculating metrics, and allocating sparsity.
Most pruning algorithms rely on a metric to decide what should be pruned. Taking the L1 norm pruner as an example,
the first step collects the model weights, the second step calculates the L1 norm of the weight for each output channel,
and the third step ranks the L1 norm metrics and masks the output channels with small L1 norms.
In the NNI basic pruner, these three steps are implemented as ``DataCollector``, ``MetricsCalculator`` and ``SparsityAllocator``.

- ``DataCollector``: This module takes the pruner as its initialization parameter.
  It gets the relevant information about the model from the pruner,
  and sometimes also hooks the model to capture the input, output or gradient of a layer or a tensor.
  It can also patch the optimizer if special steps need to be executed before or after ``optimizer.step()``.
- ``MetricsCalculator``: This module takes the data collected by the ``DataCollector``
  and calculates the metrics. The metric shape is usually reduced from the data shape.
  The ``dim`` taken by ``MetricsCalculator`` specifies which dimension is kept after the metrics are calculated,
  e.g., if the collected data shape is (10, 20, 30) and ``dim`` is 1, then dimension 1 is kept
  and the output metric shape is (20,).
- ``SparsityAllocator``: This module takes the metrics and generates the masks.
  Different ``SparsityAllocator`` classes use different mask generation strategies.
  A common and simple strategy is sorting the metric values, calculating a threshold according to the configured sparsity,
  and masking the positions whose metric values are smaller than the threshold.
  The ``dim`` taken by ``SparsityAllocator`` specifies which dimension the metrics correspond to; the mask is expanded to the weight shape.
  For example, if the metric shape is (20,), the corresponding layer weight shape is (20, 40), and ``dim`` is 0,
  ``SparsityAllocator`` first generates a mask with shape (20,), then expands it to shape (20, 40).
  A short shape sketch follows this list.
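
As a rough sketch of these shape conventions (plain PyTorch, independent of NNI; the names ``data``,
``metric`` and ``mask`` below are made up for illustration)::

    import torch

    data = torch.randn(10, 20, 30)              # collected data
    metric = data.abs().sum(dim=(0, 2))         # keep dim 1 -> shape (20,)

    weight = torch.randn(20, 40)                # the metric corresponds to dim 0 of the weight
    kept = torch.topk(metric, k=10).indices     # keep the 10 largest channels
    mask = torch.zeros(20)
    mask[kept] = 1.
    mask = mask.unsqueeze(1).expand_as(weight)  # expand (20,) -> (20, 40)
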
Simple Example: Customize a Block-L1NormPruner
----------------------------------------------
NNI already has an L1NormPruner, but to reproduce the paper and reduce the number of user configuration items,
it only supports pruning output channels. In this example, we customize a pruner that supports block granularity for Linear layers.
Note that you don't need to implement all three kinds of tools every time;
NNI provides many predefined tools that you can use directly to customize your own pruner.
Since this is a tutorial, we show how to define all three kinds of pruning tools.
First, customize the pruning tools used by the pruner.
"""
import torch
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
DataCollector,
MetricsCalculator,
SparsityAllocator
)
# This data collector collects the weight of each wrapped module as data.
# A wrapped module is a module configured in the pruner's config_list.
# This implementation is similar to nni.algorithms.compression.v2.pytorch.pruning.tools.WeightDataCollector.
class WeightDataCollector(DataCollector):
def collect(self):
data = {}
        # get_modules_wrapper returns all the wrappers in the compressor (pruner)
        # as a dict with format {wrapper_name: wrapper};
        # use wrapper.module to get the wrapped module.
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data
# return {wrapper_name: weight_data}
return data
class BlockNormMetricsCalculator(MetricsCalculator):
def __init__(self, block_sparse_size):
        # Because we keep all dimensions with block granularity, we fix ``dim=None``,
        # which means all dimensions are kept.
super().__init__(dim=None, block_sparse_size=block_sparse_size)
def calculate_metrics(self, data):
data_length = len(self.block_sparse_size)
reduce_unfold_dims = list(range(data_length, 2 * data_length))
metrics = {}
for name, t in data.items():
            # Unfold t into blocks of ``block_sparse_size`` and calculate the L1 norm of each block.
for dim, size in enumerate(self.block_sparse_size):
t = t.unfold(dim, size, size)
metrics[name] = t.norm(dim=reduce_unfold_dims, p=1)
# return {wrapper_name: block_metric}
return metrics
# This implementation is similar to nni.algorithms.compression.v2.pytorch.pruning.tools.NormalSparsityAllocator.
class BlockSparsityAllocator(SparsityAllocator):
def __init__(self, pruner, block_sparse_size):
super().__init__(pruner, dim=None, block_sparse_size=block_sparse_size, continuous_mask=True)
def generate_sparsity(self, metrics):
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
            # wrapper.config['total_sparsity'] gives the configured sparsity ratio for this wrapped module.
sparsity_rate = wrapper.config['total_sparsity']
# get metric for this wrapped module
metric = metrics[name]
            # Mask the metric with the old mask so that already-masked positions are never recovered.
            # If you are new to NNI pruning, simply keeping this logic is fine.
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
# convert sparsity ratio to prune number
prune_num = int(sparsity_rate * metric.numel())
# calculate the metric threshold
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
            # Generate the mask, keeping the positions whose metric values are greater than the threshold.
mask = torch.gt(metric, threshold).type_as(metric)
            # Expand the mask to the weight size: a masked block is filled with zeros,
            # an unmasked block with ones.
masks[name] = self._expand_mask(name, mask)
            # Merge the new mask with the old mask so that already-masked positions are never recovered.
            # If you are new to NNI pruning, simply keeping this logic is fine.
if self.continuous_mask:
masks[name]['weight'] *= wrapper.weight_mask
return masks
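
# %%
# To build intuition for the unfold-based metric above, here is a small standalone check
# (plain PyTorch, not part of the pruner itself): unfolding an (8, 4) weight into 2x2 blocks
# yields a (4, 2, 2, 2) tensor, and taking the L1 norm over the last two dimensions gives one
# value per block, i.e. a (4, 2) metric tensor.

example_weight = torch.ones(8, 4)
unfolded = example_weight.unfold(0, 2, 2).unfold(1, 2, 2)
print('unfolded shape:', unfolded.shape)                            # torch.Size([4, 2, 2, 2])
print('block metric shape:', unfolded.norm(dim=[2, 3], p=1).shape)  # torch.Size([4, 2])
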
# %%
# Customize the pruner.
class BlockL1NormPruner(BasicPruner):
def __init__(self, model, config_list, block_sparse_size):
self.block_sparse_size = block_sparse_size
super().__init__(model, config_list)
    # Implementing reset_tools is enough for this pruner.
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = BlockNormMetricsCalculator(self.block_sparse_size)
if self.sparsity_allocator is None:
self.sparsity_allocator = BlockSparsityAllocator(self, self.block_sparse_size)
# %%
# Try this pruner.
# Define a simple model.
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(4, 8)
self.fc2 = torch.nn.Linear(8, 4)
def forward(self, x):
return self.fc2(self.fc1(x))
model = TestModel()
config_list = [{'op_types': ['Linear'], 'total_sparsity': 0.5}]
# use 2x2 block
_, masks = BlockL1NormPruner(model, config_list, [2, 2]).compress()
# show the generated masks
print('fc1 masks:\n', masks['fc1']['weight'])
print('fc2 masks:\n', masks['fc2']['weight'])
# %%
# This time we have successfully defined a new pruner with block granularity!
# Note that we didn't include validation logic in this example (e.g. ``_validate_config_before_canonical``),
# but for a robust implementation, we suggest adding it. A purely illustrative sketch of such a check follows.
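#
# The helper below is made up for this tutorial (it is not an NNI API): it only sanity-checks
# the user configuration before the pruner is constructed.

def check_block_pruner_config(config_list, block_sparse_size):
    # Each config entry should target some ops and carry a sparsity ratio in [0, 1).
    for config in config_list:
        assert 'op_types' in config or 'op_names' in config, 'specify op_types or op_names'
        sparsity = config.get('total_sparsity')
        assert sparsity is not None and 0. <= sparsity < 1., 'total_sparsity should be in [0, 1)'
    # Block sizes should be positive integers.
    assert all(isinstance(size, int) and size > 0 for size in block_sparse_size), \
        'block_sparse_size should contain positive integers'

check_block_pruner_config(config_list, [2, 2])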
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/pruning_customize.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_tutorials_pruning_customize.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorials_pruning_customize.py:
Customize Basic Pruner
======================
Users can easily customize a basic pruner in NNI. A large number of basic modules have been provided and can be reused.
Following the NNI pruning interface, users only need to focus on the creative parts without worrying about the other regular modules.
In this tutorial, we show how to customize a basic pruner.
Concepts
--------
NNI abstracts the basic pruning process into three steps: collecting data, calculating metrics, and allocating sparsity.
Most pruning algorithms rely on a metric to decide what should be pruned. Taking the L1 norm pruner as an example,
the first step collects the model weights, the second step calculates the L1 norm of the weight for each output channel,
and the third step ranks the L1 norm metrics and masks the output channels with small L1 norms.
In the NNI basic pruner, these three steps are implemented as ``DataCollector``, ``MetricsCalculator`` and ``SparsityAllocator``.

- ``DataCollector``: This module takes the pruner as its initialization parameter.
  It gets the relevant information about the model from the pruner,
  and sometimes also hooks the model to capture the input, output or gradient of a layer or a tensor.
  It can also patch the optimizer if special steps need to be executed before or after ``optimizer.step()``.
- ``MetricsCalculator``: This module takes the data collected by the ``DataCollector``
  and calculates the metrics. The metric shape is usually reduced from the data shape.
  The ``dim`` taken by ``MetricsCalculator`` specifies which dimension is kept after the metrics are calculated,
  e.g., if the collected data shape is (10, 20, 30) and ``dim`` is 1, then dimension 1 is kept
  and the output metric shape is (20,).
- ``SparsityAllocator``: This module takes the metrics and generates the masks.
  Different ``SparsityAllocator`` classes use different mask generation strategies.
  A common and simple strategy is sorting the metric values, calculating a threshold according to the configured sparsity,
  and masking the positions whose metric values are smaller than the threshold.
  The ``dim`` taken by ``SparsityAllocator`` specifies which dimension the metrics correspond to; the mask is expanded to the weight shape.
  For example, if the metric shape is (20,), the corresponding layer weight shape is (20, 40), and ``dim`` is 0,
  ``SparsityAllocator`` first generates a mask with shape (20,), then expands it to shape (20, 40).
  A short shape sketch follows this list.
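
As a rough sketch of these shape conventions (plain PyTorch, independent of NNI; the names ``data``,
``metric`` and ``mask`` below are made up for illustration):

.. code-block:: python

    import torch

    data = torch.randn(10, 20, 30)              # collected data
    metric = data.abs().sum(dim=(0, 2))         # keep dim 1 -> shape (20,)

    weight = torch.randn(20, 40)                # the metric corresponds to dim 0 of the weight
    kept = torch.topk(metric, k=10).indices     # keep the 10 largest channels
    mask = torch.zeros(20)
    mask[kept] = 1.
    mask = mask.unsqueeze(1).expand_as(weight)  # expand (20,) -> (20, 40)
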
Simple Example: Customize a Block-L1NormPruner
----------------------------------------------
NNI already has an L1NormPruner, but to reproduce the paper and reduce the number of user configuration items,
it only supports pruning output channels. In this example, we customize a pruner that supports block granularity for Linear layers.
Note that you don't need to implement all three kinds of tools every time;
NNI provides many predefined tools that you can use directly to customize your own pruner.
Since this is a tutorial, we show how to define all three kinds of pruning tools.
First, customize the pruning tools used by the pruner.
.. GENERATED FROM PYTHON SOURCE LINES 51-128
.. code-block:: default
import torch
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
DataCollector,
MetricsCalculator,
SparsityAllocator
)
    # This data collector collects the weight of each wrapped module as data.
    # A wrapped module is a module configured in the pruner's config_list.
    # This implementation is similar to nni.algorithms.compression.v2.pytorch.pruning.tools.WeightDataCollector.
class WeightDataCollector(DataCollector):
def collect(self):
data = {}
            # get_modules_wrapper returns all the wrappers in the compressor (pruner)
            # as a dict with format {wrapper_name: wrapper};
            # use wrapper.module to get the wrapped module.
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data
# return {wrapper_name: weight_data}
return data
class BlockNormMetricsCalculator(MetricsCalculator):
def __init__(self, block_sparse_size):
            # Because we keep all dimensions with block granularity, we fix ``dim=None``,
            # which means all dimensions are kept.
super().__init__(dim=None, block_sparse_size=block_sparse_size)
def calculate_metrics(self, data):
data_length = len(self.block_sparse_size)
reduce_unfold_dims = list(range(data_length, 2 * data_length))
metrics = {}
for name, t in data.items():
                # Unfold t into blocks of ``block_sparse_size`` and calculate the L1 norm of each block.
for dim, size in enumerate(self.block_sparse_size):
t = t.unfold(dim, size, size)
metrics[name] = t.norm(dim=reduce_unfold_dims, p=1)
# return {wrapper_name: block_metric}
return metrics
    # This implementation is similar to nni.algorithms.compression.v2.pytorch.pruning.tools.NormalSparsityAllocator.
class BlockSparsityAllocator(SparsityAllocator):
def __init__(self, pruner, block_sparse_size):
super().__init__(pruner, dim=None, block_sparse_size=block_sparse_size, continuous_mask=True)
def generate_sparsity(self, metrics):
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
                # wrapper.config['total_sparsity'] gives the configured sparsity ratio for this wrapped module.
sparsity_rate = wrapper.config['total_sparsity']
# get metric for this wrapped module
metric = metrics[name]
                # Mask the metric with the old mask so that already-masked positions are never recovered.
                # If you are new to NNI pruning, simply keeping this logic is fine.
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
# convert sparsity ratio to prune number
prune_num = int(sparsity_rate * metric.numel())
# calculate the metric threshold
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
                # Generate the mask, keeping the positions whose metric values are greater than the threshold.
mask = torch.gt(metric, threshold).type_as(metric)
                # Expand the mask to the weight size: a masked block is filled with zeros,
                # an unmasked block with ones.
masks[name] = self._expand_mask(name, mask)
                # Merge the new mask with the old mask so that already-masked positions are never recovered.
                # If you are new to NNI pruning, simply keeping this logic is fine.
if self.continuous_mask:
masks[name]['weight'] *= wrapper.weight_mask
return masks
.. GENERATED FROM PYTHON SOURCE LINES 129-130
Customize the pruner.
.. GENERATED FROM PYTHON SOURCE LINES 130-148
.. code-block:: default
class BlockL1NormPruner(BasicPruner):
def __init__(self, model, config_list, block_sparse_size):
self.block_sparse_size = block_sparse_size
super().__init__(model, config_list)
        # Implementing reset_tools is enough for this pruner.
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = BlockNormMetricsCalculator(self.block_sparse_size)
if self.sparsity_allocator is None:
self.sparsity_allocator = BlockSparsityAllocator(self, self.block_sparse_size)
.. GENERATED FROM PYTHON SOURCE LINES 149-150
Try this pruner.
.. GENERATED FROM PYTHON SOURCE LINES 150-171
.. code-block:: default
# Define a simple model.
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(4, 8)
self.fc2 = torch.nn.Linear(8, 4)
def forward(self, x):
return self.fc2(self.fc1(x))
model = TestModel()
config_list = [{'op_types': ['Linear'], 'total_sparsity': 0.5}]
# use 2x2 block
_, masks = BlockL1NormPruner(model, config_list, [2, 2]).compress()
# show the generated masks
print('fc1 masks:\n', masks['fc1']['weight'])
print('fc2 masks:\n', masks['fc2']['weight'])
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
fc1 masks:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
fc2 masks:
tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.]])
.. GENERATED FROM PYTHON SOURCE LINES 172-175
This time we have successfully defined a new pruner with block granularity!
Note that we didn't include validation logic in this example (e.g. ``_validate_config_before_canonical``),
but for a robust implementation, we suggest adding it.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 1.175 seconds)
.. _sphx_glr_download_tutorials_pruning_customize.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: pruning_customize.py <pruning_customize.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: pruning_customize.ipynb <pruning_customize.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_