Unverified Commit aa2cc922 authored by Bill Wu's avatar Bill Wu Committed by GitHub
Browse files

Transformer Head Pruner (#3884)

parent 370e88df
...@@ -91,6 +91,8 @@ Pruners ...@@ -91,6 +91,8 @@ Pruners
.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner .. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
:members: :members:
.. autoclass:: nni.algorithms.compression.pytorch.pruning.transformer_pruner.TransformerHeadPruner
:members:
Quantizers Quantizers
^^^^^^^^^^ ^^^^^^^^^^
......
...@@ -35,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms. ...@@ -35,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms.
Pruning Algorithms Pruning Algorithms
^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^
Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-tting issue. Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue.
.. list-table:: .. list-table::
:header-rows: 1 :header-rows: 1
...@@ -73,6 +73,8 @@ Pruning algorithms compress the original network by removing redundant weights o ...@@ -73,6 +73,8 @@ Pruning algorithms compress the original network by removing redundant weights o
- Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner `Reference Paper <https://arxiv.org/abs/1907.03141>`__ - Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner `Reference Paper <https://arxiv.org/abs/1907.03141>`__
* - `AMC Pruner <../Compression/Pruner.rst#amc-pruner>`__ * - `AMC Pruner <../Compression/Pruner.rst#amc-pruner>`__
- AMC: AutoML for Model Compression and Acceleration on Mobile Devices `Reference Paper <https://arxiv.org/pdf/1802.03494.pdf>`__ - AMC: AutoML for Model Compression and Acceleration on Mobile Devices `Reference Paper <https://arxiv.org/pdf/1802.03494.pdf>`__
* - `Transformer Head Pruner <../Compression/Pruner.rst#transformer-head-pruner>`__
- Pruning attention heads from transformer models either in one shot or iteratively.
You can refer to this `benchmark <../CommunitySharings/ModelCompressionComparison.rst>`__ for the performance of these pruners on some benchmark problems. You can refer to this `benchmark <../CommunitySharings/ModelCompressionComparison.rst>`__ for the performance of these pruners on some benchmark problems.
......
...@@ -28,6 +28,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a ...@@ -28,6 +28,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
**Others** **Others**
* `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__ * `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__
* `Transformer Head Pruner <#transformer-head-pruner>`__
Level Pruner Level Pruner
------------ ------------
...@@ -724,3 +725,95 @@ User configuration for Sensitivity Pruner ...@@ -724,3 +725,95 @@ User configuration for Sensitivity Pruner
**PyTorch** **PyTorch**
.. autoclass:: nni.algorithms.compression.pytorch.pruning.SensitivityPruner .. autoclass:: nni.algorithms.compression.pytorch.pruning.SensitivityPruner
Transformer Head Pruner
-----------------------
Transformer Head Pruner is a tool designed for pruning attention heads from the models belonging to the `Transformer family <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__. The following image from `Efficient Transformers: A Survey <https://arxiv.org/pdf/2009.06732.pdf>`__ gives a good overview the general structure of the Transformer.
.. image:: ../../img/transformer_structure.png
:target: ../../img/transformer_structure.png
:alt:
Typically, each attention layer in the Transformer models consists of four weights: three projection matrices for query, key, value, and an output projection matrix. The outputs of the former three matrices contains the projected results for all heads. Normally, the results are then reshaped so that each head performs that attention computation independently. The final results are concatenated back before fed into the output projection. Therefore, when an attention head is pruned, the same weights corresponding to that heads in the three projection matrices are pruned. Also, the weights in the output projection corresponding to the head's output are pruned. In our implementation, we calculate and apply masks to the four matrices together.
Note: currently, the pruner can only handle models with projection weights written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to query, key, value, and an output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field.
The pruner implements the following algorithm:
.. code-block:: bash
Repeat for each pruning iteration (1 for one-shot pruning):
1. Calculate importance scores for each head in each specified layer using a specific criterion.
2. Sort heads locally or globally, and prune out some heads with lowest scores. The number of pruned heads is determined according to the sparsity specified in the config.
3. If the specified pruning iteration is larger than 1 (iterative pruning), finetune the model for a while before the next pruning iteration.
Currently, the following head sorting criteria are supported:
* "l1_weight": rank heads by the L1-norm of weights of the query, key, and value projection matrices.
* "l2_weight": rank heads by the L2-norm of weights of the query, key, and value projection matrices.
* "l1_activation": rank heads by the L1-norm of their attention computation output.
* "l2_activation": rank heads by the L2-norm of their attention computation output.
* "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__.
We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), and you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be prune to the same sparsity as specified. This sparsity value will be interpreted as a global sparsity, and each layer is likely to have different sparsity after pruning by global sort.
In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner's initialization parameters (usage below), or simply pass a dummy input and the pruner will run ``torch.jit.trace`` to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also note that weights belonging to the same layer must have the same sparsity.
In addition to the following usage guide, we provide a more detailed example of pruning BERT for tasks from the GLUE benchmark. Please find it in this :githublink:`page <examples/model_compress/pruning/transformers>`.
Usage
^^^^^
Suppose we want to prune a BERT with Huggingface implementation, which has the following architecture (obtained by calling ``print(model)``). Note that we only show the first layer of the repeated layers in the encoder's ``ModuleList layer``.
.. image:: ../../img/huggingface_bert_architecture.png
:target: ../../img/huggingface_bert_architecture.png
:alt:
**Usage Example: one-shot pruning, assigning sparsity 0.5 to the first six layers and sparsity 0.25 to the last six layers (PyTorch code)**. Note that
* Here we specify ``op_names`` in the config list to assign different sparsity to different layers.
* Meanwhile, we pass ``attention_name_groups`` to the pruner so that the pruner may group together the weights belonging to the same attention layer.
* Since in this example we want to do one-shot pruning, the ``num_iterations`` parameter is set to 1, and the parameter ``epochs_per_iteration`` is ignored. If you would like to do iterative pruning instead, you can set the ``num_iterations`` parameter to the number of pruning iterations, and the ``epochs_per_iteration`` parameter to the number of finetuning epochs between two iterations.
* The arguments ``trainer`` and ``optimizer`` are only used when we want to do iterative pruning, or the ranking criterion is ``taylorfo``. Here these two parameters are ignored by the pruner.
* The argument ``forward_runner`` is only used when the ranking criterion is ``l1_activation`` or ``l2_activation``. Here this parameter is ignored by the pruner.
.. code-block:: python
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
attention_name_groups = list(zip(["encoder.layer.{}.attention.self.query".format(i) for i in range(12)],
["encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
["encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))
kwargs = {"ranking_criterion": "l1_weight",
"global_sort": False,
"num_iterations": 1,
"epochs_per_iteration": 1, # this is ignored when num_iterations = 1
"head_hidden_dim": 64,
"attention_name_groups": attention_name_groups,
"trainer": trainer,
"optimizer": optimizer,
"forward_runner": forward_runner
}
config_list = [{
"sparsity": 0.5,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer] # first six layers
},
{
"sparsity": 0.25,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer] # last six layers
}
]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
User configuration for Transformer Head Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
.. autoclass:: nni.algorithms.compression.pytorch.pruning.TransformerHeadPruner
#!/bin/bash
# Usage: ./run.sh gpu_id glue_task
export CUDA_VISIBLE_DEVICES=$1
TASK_NAME=$2 # "cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte", "wnli"
PRETRAINED_MODEL="bert-base-uncased" # "distilbert-base-uncased", "roberta-base", "bert-base-cased", ...
# parameters for pruning
SPARSITY=0.5
RANKING_CRITERION=l1_weight # "l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"
NUM_ITERATIONS=1 # 1 for one-shot pruning
EPOCHS_PER_ITERATION=1
# other training parameters, no need to change
MAX_LENGTH=128
BATCH_SIZE=32
LR=2e-5
N_EPOCHS=3
time=$(date "+%Y%m%d%H%M%S")
OUTDIR="models_${PRETRAINED_MODEL}_${TASK_NAME}_$time/"
TASK_LIST=("cola" "sst2" "mrpc" "stsb" "qqp" "mnli" "qnli" "rte" "wnli")
if [[ ${TASK_LIST[*]} =~ (^|[[:space:]])$TASK_NAME($|[[:space:]]) ]]; then
mkdir $OUTDIR
python transformer_pruning.py \
--sparsity $SPARSITY \
--ranking_criterion $RANKING_CRITERION \
--num_iterations $NUM_ITERATIONS \
--epochs_per_iteration $EPOCHS_PER_ITERATION \
--speed_up \
--model_name $PRETRAINED_MODEL \
--task_name $TASK_NAME \
--max_length $MAX_LENGTH \
--batch_size $BATCH_SIZE \
--learning_rate $LR \
--num_train_epochs $N_EPOCHS \
--output_dir $OUTDIR \
2>&1 | tee "$OUTDIR/output.log"
else
echo "Unsupported task $TASK_NAME."
fi
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import math
import os
import random
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import nni
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
import datasets
from datasets import load_dataset, load_metric
import transformers
from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
PretrainedConfig,
default_data_collator,
get_scheduler,
)
logger = logging.getLogger("bert_pruning_example")
def parse_args():
parser = argparse.ArgumentParser(description="Example: prune a Huggingface transformer and finetune on GLUE tasks.")
parser.add_argument("--model_name", type=str, required=True,
help="Pretrained model architecture.")
parser.add_argument("--task_name", type=str, default=None,
help="The name of the GLUE task.",
choices=["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"])
parser.add_argument("--output_dir", type=str, default=None,
help="Where to store the model and mask.")
parser.add_argument("--sparsity", type=float, required=True,
help="Sparsity: proportion of heads to prune (should be between 0 and 1)")
parser.add_argument("--global_sort", action="store_true", default=False,
help="Rank the heads globally and prune the heads with lowest scores. If set to False, the "
"heads are only ranked within one layer")
parser.add_argument("--ranking_criterion", type=str, default="l1_weight",
choices=["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"],
help="Criterion by which the attention heads are ranked.")
parser.add_argument("--num_iterations", type=int, default=1,
help="Number of pruning iterations (1 for one-shot pruning).")
parser.add_argument("--epochs_per_iteration", type=int, default=1,
help="Epochs to finetune before the next pruning iteration "
"(only effective if num_iterations > 1).")
parser.add_argument("--speed_up", action="store_true", default=False,
help="Whether to speed-up the pruned model")
# parameters for model training; no need to change them for running examples
parser.add_argument("--max_length", type=int, default=128,
help=("The maximum total input sequence length after tokenization. Sequences longer than this "
"will be truncated, sequences shorter will be padded if `--pad_to_max_lengh` is passed."))
parser.add_argument("--batch_size", type=int, default=8,
help="Batch size.")
parser.add_argument("--learning_rate", type=float, default=5e-5,
help="Initial learning rate.")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Total number of training epochs to perform.")
parser.add_argument("--lr_scheduler_type", default="linear",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant",
"constant_with_warmup"])
parser.add_argument("--num_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler.")
args = parser.parse_args()
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
return args
def get_raw_dataset(task_name):
"""
Get a GLUE dataset using huggingface datasets.
"""
raw_dataset = load_dataset("glue", task_name)
is_regression = task_name == "stsb"
num_labels = 1 if is_regression else len(raw_dataset["train"].features["label"].names)
return raw_dataset, is_regression, num_labels
def preprocess(args, tokenizer, raw_dataset):
"""
Tokenization and column renaming.
"""
assert args.task_name is not None
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[args.task_name]
def tokenize(data):
texts = (
(data[sentence1_key],) if sentence2_key is None else (data[sentence1_key], data[sentence2_key])
)
result = tokenizer(*texts, padding=False, max_length=args.max_length, truncation=True)
if "label" in data:
result["labels"] = data["label"]
return result
processed_datasets = raw_dataset.map(tokenize, batched=True, remove_columns=raw_dataset["train"].column_names)
return processed_datasets
def get_dataloader_and_optimizer(args, tokenizer, model, train_dataset, eval_dataset):
data_collator = DataCollatorWithPadding(tokenizer)
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator,
batch_size=args.batch_size)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator,
batch_size=args.batch_size)
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
return optimizer, train_dataloader, eval_dataloader, data_collator
def train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device):
"""
Train the model using train_dataloader and evaluate after every epoch using eval_dataloader.
This function is called before and after pruning for "pretraining" on the GLUE task and further "finetuning".
"""
train_steps = args.num_train_epochs * len(train_dataloader)
progress_bar = tqdm(range(train_steps), position=0, leave=True)
for epoch in range(args.num_train_epochs):
model.train()
for step, batch in enumerate(train_dataloader):
for field in batch.keys():
batch[field] = batch[field].to(device)
outputs = model(**batch)
outputs.loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
model.eval()
for step, batch in enumerate(eval_dataloader):
for field in batch.keys():
batch[field] = batch[field].to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(predictions=predictions, references=batch["labels"])
eval_metric = metric.compute()
logger.info(f"epoch {epoch}: {eval_metric}")
def trainer_helper(model, train_dataloader, optimizer, device):
"""
This function is used for to create a "trainer" that is passed to the pruner.
Finetune the model for 1 epoch. This function is called by the pruner during pruning iterations (or called to
calculate scores for pruning when ranking criterion is "taylorfo").
"""
logger.info("Training for 1 epoch...")
progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True)
train_epoch = 1
for epoch in range(train_epoch):
for step, batch in enumerate(train_dataloader):
for field in batch.keys():
batch[field] = batch[field].to(device)
outputs = model(**batch)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
progress_bar.update(1)
def forward_runner_helper(model, train_dataloader, device):
"""
This function is used for to create a "forward_runner" that is passed to the pruner.
The function just runs forward on the train set without updating the parameters.
This allows the pruner to collect data for activation-based pruning methods.
"""
logger.info("Running forward on the entire train set without updating parameters...")
progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True)
forward_epoch = 1
for epoch in range(forward_epoch):
for step, batch in enumerate(train_dataloader):
for field in batch.keys():
batch[field] = batch[field].to(device)
_ = model(**batch)
# note: no loss.backward or optimizer.step() is performed here
progress_bar.update(1)
def final_eval_for_mnli(args, model, processed_datasets, metric, data_collator):
"""
If the task is MNLI, perform a final evaluation on mismatched validation set
"""
eval_dataset = processed_datasets["validation_mismatched"]
eval_dataloader = DataLoader(
eval_dataset, collate_fn=data_collator, batch_size=args.batch_size
)
model.eval()
for step, batch in enumerate(eval_dataloader):
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
metric.add_batch(
predictions=predictions,
references=batch["labels"],
)
eval_metric = metric.compute()
logger.info(f"mnli-mm: {eval_metric}")
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = parse_args()
#########################################################################
# Prepare model, tokenizer, dataset, optimizer, and the scheduler
logger.setLevel(logging.INFO)
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
# Load dataset and tokenizer, and then preprocess the dataset
raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
processed_datasets = preprocess(args, tokenizer, raw_dataset)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
# Load pretrained model
config = AutoConfig.from_pretrained(args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
model.to(device)
#########################################################################
# Finetune on the target GLUE task before pruning
optimizer, train_dataloader, eval_dataloader, data_collator = get_dataloader_and_optimizer(args, tokenizer,
model,
train_dataset,
eval_dataset)
train_steps = args.num_train_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
num_training_steps=train_steps)
metric = load_metric("glue", args.task_name)
logger.info("================= Finetuning before pruning =================")
train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device)
if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir + "/model_before_pruning.pt")
if args.task_name == "mnli":
final_eval_for_mnli(args, model, processed_datasets, metric, data_collator)
#########################################################################
# Pruning
optimizer, train_dataloader, eval_dataloader, data_collator = get_dataloader_and_optimizer(args, tokenizer,
model,
train_dataset,
eval_dataset)
dummy_input = next(iter(train_dataloader))["input_ids"].to(device)
flops, params, results = count_flops_params(model, dummy_input)
print(f"Initial model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M")
# Here criterion is embedded in the model. Upper levels can just pass None to trainer.
def trainer(model, optimizer, criterion, epoch):
return trainer_helper(model, train_dataloader, optimizer, device)
def forward_runner(model):
return forward_runner_helper(model, train_dataloader, device)
# example: prune different layers with different sparsity
attention_name_groups = list(zip(["bert.encoder.layer.{}.attention.self.query".format(i) for i in range(12)],
["bert.encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
["bert.encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
["bert.encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))
kwargs = {"ranking_criterion": args.ranking_criterion,
"global_sort": args.global_sort,
"num_iterations": args.num_iterations,
"epochs_per_iteration": args.epochs_per_iteration,
"attention_name_groups": attention_name_groups,
"head_hidden_dim": 64,
"trainer": trainer,
"optimizer": optimizer,
"forward_runner": forward_runner}
config_list = [{
"sparsity": args.sparsity,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer]
},
{
"sparsity": args.sparsity / 2,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer]
}
]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
#########################################################################
# uncomment the following part to export the pruned model masks
# model_path = os.path.join(args.output_dir, "pruned_{}_{}.pth".format(args.model_name, args.task_name))
# mask_path = os.path.join(args.output_dir, "mask_{}_{}.pth".format(args.model_name, args.task_name))
# pruner.export_model(model_path=model_path, mask_path=mask_path)
#########################################################################
# Speedup
# Currently, speeding up Transformers through NNI ModelSpeedup is not supported because of shape inference issues.
# However, if you are using the transformers library, you can use the following workaround:
# The following code gets the head pruning decisions from the pruner and calls the _prune_heads() function
# implemented in models from the transformers library to speed up the model.
if args.speed_up:
speedup_rules = {}
for group_idx, group in enumerate(pruner.attention_name_groups):
# get the layer index
layer_idx = None
for part in group[0].split("."):
try:
layer_idx = int(part)
break
except:
continue
if layer_idx is not None:
speedup_rules[layer_idx] = pruner.pruned_heads[group_idx]
pruner._unwrap_model()
model.bert._prune_heads(speedup_rules)
print(model)
#########################################################################
# After pruning, finetune again on the target task
# Get the metric function
metric = load_metric("glue", args.task_name)
# re-initialize the optimizer and the scheduler
optimizer, _, _, data_collator = get_dataloader_and_optimizer(args, tokenizer, model, train_dataset,
eval_dataset)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
num_training_steps=train_steps)
logger.info("================= Finetuning after Pruning =================")
train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device)
if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir + "/model_after_pruning.pt")
if args.task_name == "mnli":
final_eval_for_mnli(args, model, processed_datasets, metric, data_collator)
flops, params, results = count_flops_params(model, dummy_input)
print(f"Final model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M")
if __name__ == "__main__":
main()
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from .finegrained_pruning_masker import * from .finegrained_pruning_masker import *
from .structured_pruning_masker import * from .structured_pruning_masker import *
from .transformer_pruning_head_masker import *
from .one_shot_pruner import * from .one_shot_pruner import *
from .iterative_pruner import * from .iterative_pruner import *
from .lottery_ticket import LotteryTicketPruner from .lottery_ticket import LotteryTicketPruner
...@@ -11,3 +12,4 @@ from .net_adapt_pruner import NetAdaptPruner ...@@ -11,3 +12,4 @@ from .net_adapt_pruner import NetAdaptPruner
from .auto_compress_pruner import AutoCompressPruner from .auto_compress_pruner import AutoCompressPruner
from .sensitivity_pruner import SensitivityPruner from .sensitivity_pruner import SensitivityPruner
from .amc import AMCPruner from .amc import AMCPruner
from .transformer_pruner import TransformerHeadPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from schema import And, Optional
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import AttentionWeightDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from . import L1WeightHeadMasker, L2WeightHeadMasker, L1ActivationHeadMasker, L2ActivationHeadMasker, TaylorFOHeadMasker
__all__ = ['TransformerHeadPruner']
MASKER_DICT = {
'l1_weight': L1WeightHeadMasker,
'l2_weight': L2WeightHeadMasker,
'l1_activation': L1ActivationHeadMasker,
'l2_activation': L2ActivationHeadMasker,
'taylorfo': TaylorFOHeadMasker
}
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class TransformerHeadPruner(Pruner):
"""
A pruner specialized for pruning attention heads in models belong to the transformer family.
Parameters
----------
model : torch.nn.Module
Model to be pruned. Expect a model from transformers library (e.g., BertModel).
This pruner can work with other customized transformer models, but some ranking modes might fail.
config_list : list
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Optional. Operation types to prune. (Should be 'Linear' for this pruner.)
- op_names : Optional. Operation names to prune.
head_hidden_dim : int
Dimension of the hidden dimension of each attention head. (e.g., 64 for BERT)
We assume that this head_hidden_dim is constant across the entire model.
attention_name_groups : list (Optional)
List of groups of names for weights of each attention layer. Each element should be a four-element list, with
the first three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.
dummy_input : torch.Tensor (Optional)
Input to model's forward method, used to infer module grouping if attention_name_groups is not specified.
This tensor is used by the underlying torch.jit.trace to infer the module graph.
ranking_criterion : str
The criterion for ranking attention heads. Currently we support:
- l1_weight: l1 norm of Q_proj, K_proj, and V_proj
- l2_weight: l2 norm of Q_proj, K_proj, and V_proj
- l1_activation: l1 norm of the output of attention computation
- l2_activation: l2 norm of the output of attention computation
- taylorfo: l1 norm of the output of attention computation * gradient for this output
(check more details in the masker documentation)
global_sort : bool
Whether rank the heads globally or locally before deciding heads to prune.
num_iterations : int
Number of pruning iterations. Defaults to 1 (ont-shot pruning). If num_iterations > 1, the pruner will split
the sparsity specified in config_list uniformly and assign a fraction to each pruning iteration.
epochs_per_iteration : int
Number of finetuning epochs before the next pruning iteration.
Only used when num_iterations > 1.
If num_iterations is 1, then no finetuning is performed by the pruner after pruning.
optimizer: torch.optim.Optimizer
Optimizer used to train model
trainer: function
Function used to finetune the model between pruning iterations.
Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'.
Users should write this function as a normal function to train the PyTorch model and include
`model, optimizer, criterion, epoch` as function arguments. Note that the trainer is also used for collecting
gradients for pruning if ranking_criterion is 'taylorfo'. In that case, ``epoch=None`` will be passed.
criterion: function
Function used to calculate the loss between the target and the output.
Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
forward_runner: function
Function used to perform a "dry run" on the model on the entire train/validation dataset in order to collect
data for pruning required by the criteria 'l1_activation' or 'l2_activation'.
Only used when ranking_criterion is 'l1_activation' or 'l2_activation'.
Users should write this function as a normal function that accepts a PyTorch model and runs forward on the model
using the entire train/validation dataset. This function is not expected to perform any backpropagation or
parameter updates.
"""
def __init__(self, model, config_list,head_hidden_dim, attention_name_groups=None, dummy_input=None,
ranking_criterion='l1_weight', global_sort=False, num_iterations=1, epochs_per_iteration=1,
optimizer=None, trainer=None, criterion=None, forward_runner=None,
**algo_kwargs):
super().__init__(model, config_list)
self.head_hidden_dim = int(head_hidden_dim)
self.attention_name_groups = attention_name_groups
self.dummy_input = dummy_input
self.ranking_criterion = ranking_criterion
assert self.ranking_criterion in ['l1_weight', 'l2_weight', 'l1_activation', 'l2_activation', 'taylorfo'], \
"Unsupported ranking criteria."
self.global_sort = global_sort
self.num_iterations = int(num_iterations)
assert self.num_iterations >= 1, "num_iterations must be greater than or equal to 1"
self.epochs_per_iteration = int(epochs_per_iteration)
self._optimizer = optimizer
self._trainer = trainer
self._criterion = criterion
self._forward_runner = forward_runner
if self.ranking_criterion in ['taylorfo'] or num_iterations > 1:
assert self._trainer is not None
assert self._optimizer is not None
if self.ranking_criterion in ['l1_activation', 'l2_activation']:
assert self._forward_runner is not None
# Group generation: one group per attention layer, four weights per group
self.masking_groups = []
if self.attention_name_groups is not None:
logger.info("Note: weights for the same attention layer are grouped using the given attention_name_groups.")
self.group_weights_by_name()
else:
assert self.dummy_input is not None
logger.info("Note: weights for the same attention layer are grouped using model graph.")
self._unwrap_model()
self.group_weight_names_by_graph()
self._wrap_model()
# Group sanity check
self.validate_weight_groups()
# Remove any mistakenly captured ungrouped modules
self._unwrap_model()
self.remove_ungrouped_modules()
self._wrap_model()
self.masker = MASKER_DICT[ranking_criterion](model, self, self.head_hidden_dim, **algo_kwargs)
self.pruned_heads = {i: set() for i in range(len(self.masking_groups))}
def group_weights_by_name(self):
"""
Populate self.masking_groups using the groups specified by user in attention_name_groups.
"""
assert len(self.masking_groups) == 0
# build up masking groups
name2group = {}
for layer_idx, layer in enumerate(self.attention_name_groups):
errmsg = 'Each name group must contain 4 weights, with the first three corresponding to Q_proj, K_proj, ' \
'V_proj (in any order) and the last one being output_proj.'
assert len(layer) == 4, errmsg
self.masking_groups.append([])
for weight in layer:
name2group[weight] = layer_idx
# group wrappers
for wrapper in self.get_modules_wrapper():
if wrapper.name in name2group:
wrapper.group_idx = name2group[wrapper.name]
self.masking_groups[name2group[wrapper.name]].append(wrapper)
logger.info('Grouping updated:')
logger.info([[x.name for x in group] for group in self.masking_groups])
def group_weight_names_by_graph(self):
"""
Populate self.attention_name_groups by running inference on the module graph.
Currently, the group inferred AttentionWeightDependency is limited to a set of four weights, with the first
three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.
"""
try:
module_graph = TorchModuleGraph(self.bound_model, self.dummy_input)
dependency_tracer = AttentionWeightDependency(traced_model=module_graph.trace)
self.attention_name_groups = dependency_tracer.dependency_sets
self.group_weights_by_name()
except Exception as e:
raise RuntimeError('Graph trace failed: please check dummy_input, or specify attention_name_groups.\n'
'Exception message: ' + str(e))
def validate_weight_groups(self):
"""
Sanity checks:
- Q, K, V projection weights in each groups must have the same shape
- output projection weight shape must match total hidden dimension (inferred from Q, K, V projection)
- Four weights in a group must have the same sparsity in their config
- If global_sort is specified, all weights must have the same sparsity
- head_hidden_dim must be a divisor of the output dimension of the projection weights (i.e., the resulting
head number must be an integer)
"""
errmsg = 'Attention weight group sanity check not passed'
sparsity = None
for group in self.masking_groups:
# allow empty groups - may be caused by config list filtering
if len(group) == 0:
continue
assert len(group) == 4, errmsg + ': each group must have four weights'
assert group[0].module.weight.size() == group[1].module.weight.size() and \
group[1].module.weight.size() == group[2].module.weight.size(), \
errmsg + ': the dimensions of Q, K, V projection matrices must be the same '
assert group[0].module.weight.size()[0] == group[3].module.weight.size()[1], \
errmsg + ': the dimension of attention results must match with input for output projection'
assert group[0].config['sparsity'] == group[1].config['sparsity'] == \
group[2].config['sparsity'] == group[3].config['sparsity'], \
errmsg + ': the sparsity of matrices in the same layer must be the same'
if sparsity is None:
sparsity = group[0].config['sparsity']
if self.global_sort:
assert sparsity == group[0].config['sparsity'], \
errmsg + ': for global_sort=True, the sparsity for all modules must be the same'
assert group[0].module.weight.size(0) % self.head_hidden_dim == 0, \
errmsg + ': head_hidden_dim must be a divisor of the output dimension of the projection weights'
def remove_ungrouped_modules(self):
"""
Remove non-attention weights that might be mistakenly captured by a simplified config_list.
Also update the corresponding list of layer information (self.modules_to_compress)
"""
care_of_modules = set([x for layer in self.masking_groups for x in layer])
modules_wrapper_new, modules_to_compress_new = [], []
for wrapper, layer_info in zip(self.modules_wrapper, self.modules_to_compress):
if wrapper in care_of_modules:
modules_wrapper_new.append(wrapper)
modules_to_compress_new.append(layer_info)
self.modules_wrapper = modules_wrapper_new
self.modules_to_compress = modules_to_compress_new
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def compress(self):
for pruning_iter in range(self.num_iterations):
if self.ranking_criterion in ['l1_activation', 'l2_activation']:
training = self.bound_model.training
self.bound_model.eval()
self._forward_runner(self.bound_model) # dry run, forward only
self.update_mask()
self.bound_model.train(training)
elif self.ranking_criterion in ['taylorfo']:
self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=None)
self.update_mask()
else:
self.update_mask()
# for iterative pruning, if not the last iteration, finetune before next iteration
# Then, reset the maskers (may create additional hooks)
if self.num_iterations > 1 and pruning_iter != self.num_iterations - 1:
for e in range(self.epochs_per_iteration):
self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=e+1)
self.masker.reset()
logger.info('Pruned heads after iteration %i', pruning_iter)
logger.info(self.pruned_heads)
def update_mask(self):
"""
Calculate and update masks for each masking group. If global_sort is set, the masks for all groups are
calculated altogether, and then the groups are updated individually.
"""
masks_for_all_groups = None
if self.global_sort:
masks_for_all_groups = self._calc_mask_global()
assert len(masks_for_all_groups) == len(self.masking_groups)
for group_idx, layer_weight_group in enumerate(self.masking_groups):
if self.global_sort:
masks = masks_for_all_groups[group_idx]
else:
masks = self._calc_mask(layer_weight_group)
if masks is not None:
for i, mask in enumerate(masks):
for mask_type in mask:
assert hasattr(layer_weight_group[i], mask_type), \
"there is no attribute '%s' in wrapper on %s" % (mask_type, layer_weight_group[i])
setattr(layer_weight_group[i], mask_type, mask[mask_type])
logger.debug(f'mask updated: {layer_weight_group[i].name} {mask_type}')
def _calc_mask(self, weight_group):
"""
Calculate mask for each group using only layer-local information.
When global_sort is set for the pruner, _calc_mask_global should be called instead of this function.
Parameters
----------
weight_group : list
A list of four wrappers generated by self.group_weights_by_name().
Returns
-------
masks : list
A four element list corresponding to the masks for each element in the four-element weight group.
Each element in masks is a dict with keys "weight_mask" and "bias_mask" (optional).
masks can be None if the underlying masker returns None. This means that the mask calculation fails.
The calling function can try recalculate the mask at a later time. Note that the calling function might need
to call masker.reset() before attempting to recalculate the mask.
"""
iter_sparsity = weight_group[0].config['sparsity'] / self.num_iterations
masks = self.masker.calc_mask(sparsity=iter_sparsity, weight_group=weight_group)
return masks
def _calc_mask_global(self):
"""
Calculate mask for all groups using global information.
Returns
-------
masks_list : list
A list corresponding to the masks for each weight group in self.masking_groups. Each element in the
returned mask_list is a four-element list corresponding to the masks for each element in a four-element
weight group.
"""
if len(self.get_modules_wrapper()) == 0:
return []
overall_sparsity = self.get_modules_wrapper()[0].config['sparsity'] / self.num_iterations
n_heads_total = 0
for group in self.masking_groups:
if len(group) != 0:
q_proj, _, _, _ = group
n_heads_total += int(q_proj.module.weight.size()[0] / self.head_hidden_dim)
n_heads_to_prune = int(n_heads_total * overall_sparsity)
return self.masker.calc_mask_global(n_heads_to_prune)
def calc_mask(self, wrapper, **kwargs):
raise RuntimeError("Applications should directly call TransformerHeadPruner's update_mask() method.")
...@@ -6,7 +6,8 @@ import logging ...@@ -6,7 +6,8 @@ import logging
import numpy as np import numpy as np
__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency'] __all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency', 'AttentionWeightDependency']
CONV_TYPE = 'aten::_convolution' CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_'] ADD_TYPES = ['aten::add', 'aten::add_']
...@@ -88,7 +89,6 @@ class ChannelDependency(Dependency): ...@@ -88,7 +89,6 @@ class ChannelDependency(Dependency):
""" """
This model analyze the channel dependencies between the conv This model analyze the channel dependencies between the conv
layers in a model. layers in a model.
Parameters Parameters
---------- ----------
model : torch.nn.Module model : torch.nn.Module
...@@ -105,12 +105,10 @@ class ChannelDependency(Dependency): ...@@ -105,12 +105,10 @@ class ChannelDependency(Dependency):
def _get_parent_layers(self, node): def _get_parent_layers(self, node):
""" """
Find the nearest father conv layers for the target node. Find the nearest father conv layers for the target node.
Parameters Parameters
--------- ---------
node : torch._C.Node node : torch._C.Node
target node. target node.
Returns Returns
------- -------
parent_layers: list parent_layers: list
...@@ -182,7 +180,6 @@ class ChannelDependency(Dependency): ...@@ -182,7 +180,6 @@ class ChannelDependency(Dependency):
means the output channel(filters) numbers of these means the output channel(filters) numbers of these
three layers should be same with each other, otherwise three layers should be same with each other, otherwise
the model may has shape conflict. the model may has shape conflict.
Output example: Output example:
Dependency Set,Convolutional Layers Dependency Set,Convolutional Layers
Set 1,layer1.1.conv2,layer1.0.conv2,conv1 Set 1,layer1.1.conv2,layer1.0.conv2,conv1
...@@ -219,7 +216,6 @@ class ChannelDependency(Dependency): ...@@ -219,7 +216,6 @@ class ChannelDependency(Dependency):
dependency_sets : list dependency_sets : list
list of the dependency sets. For example, list of the dependency sets. For example,
[set(['conv1', 'conv2']), set(['conv3', 'conv4'])] [set(['conv1', 'conv2']), set(['conv3', 'conv4'])]
""" """
d_sets = [] d_sets = []
visited = set() visited = set()
...@@ -256,7 +252,6 @@ class InputChannelDependency(ChannelDependency): ...@@ -256,7 +252,6 @@ class InputChannelDependency(ChannelDependency):
""" """
This model analyze the input channel dependencies between the conv This model analyze the input channel dependencies between the conv
layers in a model. layers in a model.
Parameters Parameters
---------- ----------
model : torch.nn.Module model : torch.nn.Module
...@@ -319,7 +314,6 @@ class GroupDependency(Dependency): ...@@ -319,7 +314,6 @@ class GroupDependency(Dependency):
""" """
This model analyze the group dependencis between the conv This model analyze the group dependencis between the conv
layers in a model. layers in a model.
Parameters Parameters
---------- ----------
model : torch.nn.Module model : torch.nn.Module
...@@ -336,12 +330,10 @@ class GroupDependency(Dependency): ...@@ -336,12 +330,10 @@ class GroupDependency(Dependency):
def _get_parent_convs(self, node): def _get_parent_convs(self, node):
""" """
Find the nearest father conv layers for the target node. Find the nearest father conv layers for the target node.
Parameters Parameters
--------- ---------
node : torch._C.Node node : torch._C.Node
target node. target node.
Returns Returns
------- -------
parent_layers : list parent_layers : list
...@@ -369,12 +361,10 @@ class GroupDependency(Dependency): ...@@ -369,12 +361,10 @@ class GroupDependency(Dependency):
def _get_conv_groups(self, node_group): def _get_conv_groups(self, node_group):
""" """
Get the number of groups for a convolutional layer. Get the number of groups for a convolutional layer.
Parameters Parameters
---------- ----------
node_group : NodePyGroup node_group : NodePyGroup
target node. target node.
Returns Returns
------- -------
group : int group : int
...@@ -401,7 +391,7 @@ class GroupDependency(Dependency): ...@@ -401,7 +391,7 @@ class GroupDependency(Dependency):
conv2 takes the output features of conv1 as input. conv2 takes the output features of conv1 as input.
Then we have to the filters of conv1 can still be Then we have to the filters of conv1 can still be
divided into 4 groups after filter pruning, because divided into 4 groups after filter pruning, because
the input channels of conv2 shoule be divided into the input channels of conv2 should be divided into
4 groups. 4 groups.
Returns Returns
...@@ -448,7 +438,6 @@ class GroupDependency(Dependency): ...@@ -448,7 +438,6 @@ class GroupDependency(Dependency):
line is the group count of the filters in this layer. line is the group count of the filters in this layer.
Note that, the group count may be larger than this Note that, the group count may be larger than this
layers original group number. layers original group number.
output example: output example:
Conv layer, Groups Conv layer, Groups
Conv1, 1 Conv1, 1
...@@ -468,7 +457,6 @@ class GroupDependency(Dependency): ...@@ -468,7 +457,6 @@ class GroupDependency(Dependency):
return self.dependency return self.dependency
class ReshapeDependency(Dependency): class ReshapeDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None): def __init__(self, model=None, dummy_input=None, traced_model=None):
""" """
...@@ -573,3 +561,142 @@ class ReshapeDependency(Dependency): ...@@ -573,3 +561,142 @@ class ReshapeDependency(Dependency):
d_sets.extend(self.dependency[reshape_node]) d_sets.extend(self.dependency[reshape_node])
d_sets = list(set(d_sets)) d_sets = list(set(d_sets))
return d_sets return d_sets
class AttentionWeightDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Groups the linear layers belonging to the same attention layer in a model.
Currently, we only capture weights in attention layers with forward computations written
as four Linear layers (projections for Q, K, V, and output) and two matmul operations.
The method implemented here can work for Huggingface transformers but may not correctly
capture transformers written in other fashions (e.g., torch.nn.Transformer).
Parameters
----------
model : torch.nn.Module
The model to be analyzed.
dummy_input : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(AttentionWeightDependency, self).__init__(
model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
Find the nearest parent linear layers for the target node.
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
parent_layers: list
nearest parent linear layers for the target worknode.
"""
parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Linear':
if curnode.name not in parent_layers:
parent_layers.append(curnode.name)
continue
if curnode.op_type == 'LayerNorm':
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
def _get_children_layers(self, node):
"""
Find the nearest children linear layers for the target node.
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
children_layers: list
nearest children linear layers for the target worknode.
"""
children_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Linear':
if curnode.name not in children_layers:
children_layers.append(curnode.name)
continue
if curnode.op_type == 'LayerNorm':
continue
children = self.graph.find_successors(curnode.unique_name)
children = [self.graph.name_to_node[name] for name in children]
for child in children:
queue.append(child)
return children_layers
def build_dependency(self):
"""
For every matmul operation, find the immediate parent and children Linear operations.
If we get three parents and one children, add these four weights as a dependecy group.
"""
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
layers = []
if node.op_type == 'aten::matmul':
parent_layers = self._get_parent_layers(node)
children_layers = self._get_children_layers(node)
if len(parent_layers) == 3 and len(children_layers) == 1:
layers.extend(parent_layers)
layers.extend(children_layers)
self.dependency[node.name] = layers
@property
def dependency_sets(self):
"""
Get the list of the dependency set.
Returns
-------
dependency_sets : list
list of the dependency sets.
Each dependency set is a 4-element list of module names, with the first three elements being the projection
matrices for Q, K, V (in any order), and the last element being the dense matrix.
"""
d_sets = []
for node in self.graph.nodes_py.nodes_op:
if node.op_type != 'aten::matmul' or node.name not in self.dependency or len(self.dependency[node.name]) != 4:
continue
d_sets.append(self.dependency[node.name])
return d_sets
def export(self, filepath):
"""
Export the group dependency to a csv file. Each line describes an attention layer.
Output example:
Attention layer matmul op, Group
"""
header = ['Attention layer matmul op', 'Group']
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for name in self.dependency:
group = self.dependency[name]
if len(group) > 0:
csv_w.writerow([name, group])
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
class PosEncoding(nn.Module):
def __init__(self, hidden_dim, max_seq_len=80):
super().__init__()
self.hidden_dim = hidden_dim
pe = torch.zeros(max_seq_len, hidden_dim)
for pos in range(max_seq_len):
for i in range(0, hidden_dim, 2):
pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / hidden_dim)))
pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / hidden_dim)))
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x * math.sqrt(self.hidden_dim)
x = x + torch.autograd.Variable(self.pe[:, :x.size(1)], requires_grad=False)
return x
def attention(query, key, value, mask=None, dropout=None):
d_k = query.size(-1)
logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
logits = logits.masked_fill(mask == 0, -1e9)
attention_map = F.softmax(logits, dim=-1)
if dropout is not None:
attention_map = dropout(attention_map)
return torch.matmul(attention_map, value)
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_dim, n_heads, dropout=0.1):
super().__init__()
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // n_heads
self.n_heads = n_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.output_proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# project and reshaping
k_project = self.k_proj(key)
q_project = self.q_proj(query)
v_project = self.v_proj(value)
k_reshape = k_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
q_reshape = q_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
v_reshape = v_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
# merge heads and output
scores = attention(q_reshape, k_reshape, v_reshape, mask, self.dropout)
scores = scores.transpose(1, 2).contiguous()
scores = scores.view(batch_size, -1, self.hidden_dim)
return self.output_proj(scores)
class FeedForwardLayer(nn.Module):
def __init__(self, hidden_dim, intermediate_dim=2048, dropout=0.1):
super().__init__()
self.dense1 = nn.Linear(hidden_dim, intermediate_dim)
self.dense2 = nn.Linear(intermediate_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dense2(self.dropout(F.relu(self.dense1(x))))
class LayerNorm(nn.Module):
def __init__(self, hidden_dim, eps=1e-6):
super(LayerNorm, self).__init__()
self.alpha = nn.Parameter(torch.ones(hidden_dim))
self.beta = nn.Parameter(torch.zeros(hidden_dim))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.alpha * (x - mean) / (std + self.eps) + self.beta
class TransformerEncoderLayer(nn.Module):
def __init__(self, n_heads, hidden_dim, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(hidden_dim, n_heads)
self.ff_layer = FeedForwardLayer(hidden_dim)
self.norm1 = LayerNorm(hidden_dim)
self.dropout1 = nn.Dropout(dropout)
self.norm2 = LayerNorm(hidden_dim)
self.dropout2 = nn.Dropout(dropout)
def forward(self, inp, mask):
x = self.norm1(inp)
x = inp + self.dropout1(self.self_attn(x, x, x, mask))
x = x + self.dropout2(self.ff_layer(self.norm2(x)))
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, n_heads, hidden_dim, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(hidden_dim, n_heads)
self.cross_attn = MultiHeadAttention(hidden_dim, n_heads)
self.ff = FeedForwardLayer(hidden_dim)
self.norm1 = LayerNorm(hidden_dim)
self.norm2 = LayerNorm(hidden_dim)
self.norm3 = LayerNorm(hidden_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, inp, mask, encoder_output, encoder_output_mask):
x = self.norm1(inp)
x = inp + self.dropout1(self.self_attn(x, x, x, mask))
x = x + self.dropout2(self.cross_attn(self.norm2(x), encoder_output, encoder_output, encoder_output_mask))
x = x + self.dropout3(self.ff(self.norm3(x)))
return x
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, n_layers, hidden_dim, n_heads):
super().__init__()
self.n_layers = n_layers
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.posencoding = PosEncoding(hidden_dim)
self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderLayer(n_heads, hidden_dim)) for _ in range(n_layers)])
self.layernorm = LayerNorm(hidden_dim)
def forward(self, src, mask):
x = self.embedding(src)
x = self.posencoding(x)
for i in range(self.n_layers):
x = self.layers[i](x, mask)
return self.layernorm(x)
class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, n_layers, hidden_dim, n_heads):
super().__init__()
self.n_layers = n_layers
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.posencoding = PosEncoding(hidden_dim)
self.layers = nn.ModuleList([copy.deepcopy(TransformerDecoderLayer(n_heads, hidden_dim)) for _ in range(n_layers)])
self.layernorm = LayerNorm(hidden_dim)
def forward(self, inp, mask, encoder_output, encoder_output_mask):
x = self.embedding(inp)
x = self.posencoding(x)
for i in range(self.n_layers):
x = self.layers[i](x, mask, encoder_output, encoder_output_mask)
return self.layernorm(x)
class TransformerForSeq2Seq(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, n_layers, hidden_dim, n_heads):
super().__init__()
self.encoder = TransformerEncoder(src_vocab_size, n_layers, hidden_dim, n_heads)
self.decoder = TransformerDecoder(tgt_vocab_size, n_layers, hidden_dim, n_heads)
self.output_dense = nn.Linear(hidden_dim, tgt_vocab_size)
def forward(self, src, tgt, src_mask, tgt_mask):
encoder_outputs = self.encoder(src, src_mask)
decoder_outputs = self.decoder(tgt, tgt_mask, encoder_outputs, src_mask)
return self.output_dense(decoder_outputs)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import math
import sys
import unittest
from unittest import TestCase, main
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
sys.path.append(os.path.dirname(__file__))
from models.pytorch_models.transformer import TransformerEncoder
def validate_sparsity(wrapper, sparsity, bias=False):
masks = [wrapper.weight_mask]
if bias and wrapper.bias_mask is not None:
masks.append(wrapper.bias_mask)
for m in masks:
actual_sparsity = (m == 0).sum().item() / m.numel()
msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(actual_sparsity, sparsity)
assert math.isclose(actual_sparsity, sparsity, abs_tol=0.1), msg
class Model(nn.Module):
"""
A binary classifier using a transformer encoder for contextual embedding.
"""
def __init__(self, n_layer, hidden_dim, n_head):
super(Model, self).__init__()
self.embedding = TransformerEncoder(vocab_size=100, hidden_dim=hidden_dim, n_layers=n_layer, n_heads=n_head)
self.classifier = nn.Linear(hidden_dim, 1)
def forward(self, x, mask):
raw_output = self.embedding(x, mask)
pooled_output = raw_output[0]
prediction = F.sigmoid(self.classifier(pooled_output)).squeeze()
return prediction
def train(model, dataloader, criterion, optimizer):
model.train()
device = next(model.parameters()).device
for _ in range(2):
y = torch.ones(10).to(device)
out = model(torch.randint(0, 100, (4, 10)).to(device), torch.ones(10).to(device))
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def dry_run(model):
device = next(model.parameters()).device
for _ in range(2):
y = torch.ones(10).to(device)
_ = model(torch.randint(0, 100, (4, 10)).to(device), torch.ones(10).to(device))
def head_pruner_tests(criterion, global_sort, use_graph, iterative):
print("Testing criterion {} with global_sort={} and use_graph={}".format(criterion, global_sort, use_graph))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build config list and arguments
config_list = [{'sparsity': 0.5, 'op_types': ['Linear']}]
kwargs = {'ranking_criterion': criterion, 'head_hidden_dim': 64}
if global_sort:
kwargs['global_sort'] = True
else:
kwargs['global_sort'] = False
if use_graph:
attention_name_groups = list(zip(['embedding.layers.{}.self_attn.q_proj'.format(i) for i in range(6)],
['embedding.layers.{}.self_attn.k_proj'.format(i) for i in range(6)],
['embedding.layers.{}.self_attn.v_proj'.format(i) for i in range(6)],
['embedding.layers.{}.self_attn.output_proj'.format(i) for i in range(6)]))
kwargs['attention_name_groups'] = attention_name_groups
else:
dummy_input = (torch.randint(0, 100, (10, 32)).to(device), torch.ones(32).to(device))
kwargs['dummy_input'] = dummy_input
if iterative:
kwargs['num_iterations'] = 2
kwargs['epochs_per_iteration'] = 1
n_layers = 6
n_heads = 8
hidden_dim = 512
model = Model(n_layers, hidden_dim, n_heads)
model.to(device)
kwargs['optimizer'] = torch.optim.SGD(model.parameters(), lr=0.001)
def trainer(model, optimizer, criterion, epoch):
return train(model, None, criterion, optimizer)
kwargs['trainer'] = trainer
kwargs['criterion'] = nn.BCELoss()
def forward_runner(model):
return dry_run(model)
kwargs['forward_runner'] = forward_runner
# create pruner and call compress()
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
# test model and mask export
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', device=device)
dummy_input = (torch.randint(0, 100, (10, 32)).to(device), torch.ones(32).to(device))
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth',
dummy_input=dummy_input, opset_version=10)
# validate sparsity
if not global_sort:
for wrapper in pruner.modules_wrapper:
validate_sparsity(wrapper, wrapper.config['sparsity'])
class PrunerTestCase(TestCase):
def test_head_pruner(self):
for criterion in ["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"]:
for global_sort in [False, True]:
for use_graph in [False, True]:
for iterative in [False, True]:
head_pruner_tests(criterion, global_sort, use_graph, iterative)
file_paths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv',
'./search_result.json']
for f in file_paths:
if os.path.exists(f):
os.remove(f)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment