Unverified commit e5b4bf1a authored by Bill Wu, committed by GitHub

Fix transformer pruning example (#4002)

parent dfd853d3
@@ -737,7 +737,7 @@ Transformer Head Pruner is a tool designed for pruning attention heads from the
Typically, each attention layer in Transformer models consists of four weight matrices: three projection matrices for query, key, and value, and an output projection matrix. The outputs of the first three matrices contain the projected results for all heads. Normally, these results are reshaped so that each head performs its attention computation independently, and the per-head results are concatenated back together before being fed into the output projection. Therefore, when an attention head is pruned, the weights corresponding to that head in the three projection matrices are pruned, along with the weights in the output projection that consume that head's output. In our implementation, we calculate and apply masks to the four matrices together.
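To make the weight layout concrete, below is a minimal, self-contained sketch (not the pruner's actual code) of which slices of the four matrices belong to one head and how a mask for that head would look, assuming a BERT-base-like layer whose hidden size is split evenly across the heads:

.. code-block:: python

    import torch
    import torch.nn as nn

    hidden_size, num_heads = 768, 12
    head_size = hidden_size // num_heads          # 64 for BERT-base

    # The four Linear modules of one attention layer (toy, randomly initialized).
    q_proj = nn.Linear(hidden_size, hidden_size)
    k_proj = nn.Linear(hidden_size, hidden_size)
    v_proj = nn.Linear(hidden_size, hidden_size)
    out_proj = nn.Linear(hidden_size, hidden_size)

    def masks_for_head(head_idx):
        """Masks that prune one head: output rows of Q/K/V and input columns of the output projection."""
        qkv_mask = torch.ones(hidden_size, hidden_size)
        qkv_mask[head_idx * head_size:(head_idx + 1) * head_size, :] = 0.
        out_mask = torch.ones(hidden_size, hidden_size)
        out_mask[:, head_idx * head_size:(head_idx + 1) * head_size] = 0.
        return qkv_mask, out_mask

    # Prune head 3: mask the same slice in query/key/value and the matching
    # slice of the output projection (bias terms omitted for brevity).
    qkv_mask, out_mask = masks_for_head(3)
    for proj in (q_proj, k_proj, v_proj):
        proj.weight.data.mul_(qkv_mask)
    out_proj.weight.data.mul_(out_mask)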
Note: currently, the pruner can only handle models whose projection weights are written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to the query, key, value, and output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field. For instance, models from `Huggingface transformers <https://huggingface.co/transformers/index.html>`_ are supported, but ``torch.nn.Transformer`` is not.
The pruner implements the following algorithm:
@@ -756,11 +756,9 @@ Currently, the following head sorting criteria are supported:
* "l2_activation": rank heads by the L2-norm of their attention computation output.
* "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__.
We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), controlled by the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be pruned to exactly that sparsity: the value is interpreted as a global sparsity, and each layer is likely to end up with a different sparsity after pruning by global sort. Also note that when global sorting is used, it is usually helpful to adopt an iterative pruning scheme that interleaves pruning with intermediate finetuning, since global sorting often produces a non-uniform sparsity distribution across layers, which makes the model more susceptible to forgetting.
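For instance, a global-sort configuration might look like the following sketch; the ``kwargs`` entries mirror those in the usage section below, and only the parts relevant to global sorting are shown:

.. code-block:: python

    # With global_sort=True, every entry in config_list must use the same sparsity value.
    # The 0.5 below means "prune 50% of all heads in the model"; after global ranking,
    # individual layers may end up more or less sparse than 50%.
    config_list = [{
        "sparsity": 0.5,
        "op_types": ["Linear"],
    }]
    kwargs = {
        "ranking_criterion": "l1_weight",
        "global_sort": True,
        "num_iterations": 1,   # consider > 1 with finetuning in between, as suggested above
        # ... remaining arguments as in the usage example below
    }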
In our implementation, we support two ways to group the four weights of the same layer together. You can either pass a nested list containing the names of these modules as an initialization parameter of the pruner (see the usage below), or simply pass a dummy input, in which case the pruner runs ``torch.jit.trace`` to group the weights (an experimental feature). However, if you would like to assign a different sparsity to each layer, you can only use the first option, i.e., passing the names of the weights to the pruner. Also note that weights belonging to the same layer are required to have the same sparsity.
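As a reference, the dummy-input alternative could be sketched as below. Note that the keyword name ``dummy_input`` (and passing it alongside the other arguments) is an assumption made here for illustration, so please check the pruner's parameter list before relying on it:

.. code-block:: python

    import torch

    # Experimental grouping: instead of listing module names, pass a dummy input so that
    # the pruner can run torch.jit.trace and group the query/key/value/output Linear modules itself.
    # "dummy_input" as the keyword name is an assumption; per-layer sparsity still requires
    # the explicit name-based grouping shown in the usage section below.
    dummy_input = torch.randint(0, 30522, (1, 128))   # a toy batch of BERT token ids
    kwargs = {
        "dummy_input": dummy_input,                   # hypothetical keyword, replacing the name groups
        "ranking_criterion": "l1_weight",
        "global_sort": False,
        "num_iterations": 1,
    }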
Usage
^^^^^
@@ -786,6 +784,7 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
["encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
["encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))
kwargs = {"ranking_criterion": "l1_weight",
"global_sort": False,
"num_iterations": 1,
@@ -796,20 +795,26 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
"optimizer": optimizer,
"forward_runner": forward_runner
}
config_list = [{
"sparsity": 0.5,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer] # first six layers
}, {
"sparsity": 0.25,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer] # last six layers
}]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
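After ``compress()`` finishes, the masks have been applied to the model in place. If you also want to persist the masked weights and the masks themselves, one possibility is the generic NNI pruner export shown in the sketch below; the file names are placeholders:

.. code-block:: python

    # Sketch: save the masked weights and the binary masks for later inspection or speedup.
    # export_model comes from the generic NNI Pruner interface, not anything specific to this pruner.
    pruner.export_model(model_path="pruned_bert.pth", mask_path="head_masks.pth")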
In addition to this usage guide, we provide a more detailed example of pruning BERT (Huggingface implementation) for transfer learning on tasks from the `GLUE benchmark <https://gluebenchmark.com/>`_. Please find it on this :githublink:`page <examples/model_compress/pruning/transformers>`. To run the example, first make sure that the packages ``transformers`` and ``datasets`` are installed. Then, you may start by running the following command:
.. code-block:: bash
./run.sh gpu_id glue_task
By default, the code downloads a pretrained BERT language model and finetunes it for several epochs on the downstream GLUE task. Then, the ``TransformerHeadPruner`` is used to prune heads from each layer according to a chosen criterion (by default, the pruner uses magnitude ranking and prunes 50% of the heads in each layer in a one-shot manner). Finally, the pruned model is finetuned on the downstream task for several more epochs. You can check the details of pruning in the logs printed by the example. You can also experiment with different pruning settings by changing the parameters in ``run.sh``, or by directly changing the ``config_list`` in ``transformer_pruning.py``.
User configuration for Transformer Head Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -3,33 +3,24 @@
import argparse
import logging
import math
import os
import random
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import nni
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
import datasets
from datasets import load_dataset, load_metric
import transformers
from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
PretrainedConfig,
default_data_collator,
get_scheduler,
)
@@ -38,7 +29,8 @@ logger = logging.getLogger("bert_pruning_example")
def parse_args():
parser = argparse.ArgumentParser(
description="Example: prune a Huggingface transformer and finetune on GLUE tasks.")
parser.add_argument("--model_name", type=str, required=True,
help="Pretrained model architecture.")
@@ -53,7 +45,8 @@ def parse_args():
help="Rank the heads globally and prune the heads with lowest scores. If set to False, the "
"heads are only ranked within one layer")
parser.add_argument("--ranking_criterion", type=str, default="l1_weight",
choices=["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"],
choices=["l1_weight", "l2_weight",
"l1_activation", "l2_activation", "taylorfo"],
help="Criterion by which the attention heads are ranked.")
parser.add_argument("--num_iterations", type=int, default=1,
help="Number of pruning iterations (1 for one-shot pruning).")
@@ -93,7 +86,8 @@ def get_raw_dataset(task_name):
"""
raw_dataset = load_dataset("glue", task_name)
is_regression = task_name == "stsb"
num_labels = 1 if is_regression else len(
raw_dataset["train"].features["label"].names)
return raw_dataset, is_regression, num_labels
@@ -105,29 +99,32 @@ def preprocess(args, tokenizer, raw_dataset):
assert args.task_name is not None
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[args.task_name]
def tokenize(data):
texts = (
(data[sentence1_key],) if sentence2_key is None else (
data[sentence1_key], data[sentence2_key])
)
result = tokenizer(*texts, padding=False,
max_length=args.max_length, truncation=True)
if "label" in data:
result["labels"] = data["label"]
return result
processed_datasets = raw_dataset.map(
tokenize, batched=True, remove_columns=raw_dataset["train"].column_names)
return processed_datasets
@@ -168,7 +165,8 @@ def train_model(args, model, is_regression, train_dataloader, eval_dataloader, o
for field in batch.keys():
batch[field] = batch[field].to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression \
else outputs.logits.squeeze()
metric.add_batch(predictions=predictions, references=batch["labels"])
eval_metric = metric.compute()
@@ -183,7 +181,7 @@ def trainer_helper(model, train_dataloader, optimizer, device):
"""
logger.info("Training for 1 epoch...")
progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True)
train_epoch = 1
for epoch in range(train_epoch):
for step, batch in enumerate(train_dataloader):
@@ -213,7 +211,7 @@ def forward_runner_helper(model, train_dataloader, device):
_ = model(**batch)
# note: no loss.backward or optimizer.step() is performed here
progress_bar.update(1)
def final_eval_for_mnli(args, model, processed_datasets, metric, data_collator):
"""
@@ -248,15 +246,18 @@ def main():
transformers.utils.logging.set_verbosity_info()
# Load dataset and tokenizer, and then preprocess the dataset
raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
processed_datasets = preprocess(args, tokenizer, raw_dataset)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation_matched" if args.task_name ==
"mnli" else "validation"]
# Load pretrained model
config = AutoConfig.from_pretrained(
args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name, config=config)
model.to(device)
#########################################################################
@@ -269,9 +270,10 @@ def main():
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
num_training_steps=train_steps)
metric = load_metric("glue", args.task_name)
logger.info("================= Finetuning before pruning =================")
train_model(args, model, is_regression, train_dataloader,
eval_dataloader, optimizer, lr_scheduler, metric, device)
if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir + "/model_before_pruning.pt")
@@ -316,13 +318,11 @@ def main():
"sparsity": args.sparsity,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[:6] for x in layer]
}, {
"sparsity": args.sparsity / 2,
"op_types": ["Linear"],
"op_names": [x for layer in attention_name_groups[6:] for x in layer]
}]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
@@ -360,7 +360,7 @@ def main():
# After pruning, finetune again on the target task
# Get the metric function
metric = load_metric("glue", args.task_name)
# re-initialize the optimizer and the scheduler
optimizer, _, _, data_collator = get_dataloader_and_optimizer(args, tokenizer, model, train_dataset,
eval_dataset)
@@ -368,13 +368,16 @@ def main():
num_training_steps=train_steps)
logger.info("================= Finetuning after Pruning =================")
train_model(args, model, is_regression, train_dataloader,
eval_dataloader, optimizer, lr_scheduler, metric, device)
if args.output_dir is not None:
torch.save(model.state_dict(), args.output_dir +
"/model_after_pruning.pt")
if args.task_name == "mnli":
final_eval_for_mnli(args, model, processed_datasets,
metric, data_collator)
flops, params, results = count_flops_params(model, dummy_input)
print(f"Final model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M")