Unverified Commit e5b4bf1a authored by Bill Wu, committed by GitHub

Fix transformer pruning example (#4002)

parent dfd853d3
@@ -737,7 +737,7 @@ Transformer Head Pruner is a tool designed for pruning attention heads from the
Typically, each attention layer in a Transformer model consists of four weights: three projection matrices for query, key, and value, and an output projection matrix. The outputs of the first three matrices contain the projected results for all heads. Normally, these results are then reshaped so that each head performs the attention computation independently, and the per-head results are concatenated back before being fed into the output projection. Therefore, when an attention head is pruned, the weights corresponding to that head in the three projection matrices are pruned, along with the weights in the output projection that consume that head's output. In our implementation, we calculate and apply masks to the four matrices together.
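For intuition, the following is a minimal sketch (not the pruner's internal code) of how masking one head maps onto the four matrices of a single attention layer. The dimensions are illustrative and assume a BERT-base-like layer with 12 heads of size 64.

.. code-block:: python

    import torch

    # Illustrative dimensions only (BERT-base-like): 12 heads of size 64
    hidden_size, num_heads = 768, 12
    head_size = hidden_size // num_heads

    q_proj = torch.nn.Linear(hidden_size, hidden_size)
    k_proj = torch.nn.Linear(hidden_size, hidden_size)
    v_proj = torch.nn.Linear(hidden_size, hidden_size)
    out_proj = torch.nn.Linear(hidden_size, hidden_size)

    head_to_prune = 3
    start, end = head_to_prune * head_size, (head_to_prune + 1) * head_size

    # Rows [start, end) of the Q/K/V weights (and the matching bias entries)
    # produce this head's projections, so they are masked in all three matrices.
    for proj in (q_proj, k_proj, v_proj):
        proj.weight.data[start:end, :] = 0.0
        proj.bias.data[start:end] = 0.0

    # The output projection consumes the concatenated head outputs along its
    # input dimension, so the corresponding columns are masked instead.
    out_proj.weight.data[:, start:end] = 0.0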
Note: currently, the pruner can only handle models whose projection weights are written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to the query, key, value, and output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field. For instance, the `Huggingface transformers <https://huggingface.co/transformers/index.html>`_ are supported, but ``torch.nn.Transformer`` is not.
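As a minimal sketch of the first option (selecting modules by type; the sparsity value is illustrative):

.. code-block:: python

    # Prune 25% of the heads in every attention layer; all Linear modules are
    # selected by type, so no module names need to be listed explicitly.
    config_list = [{
        "sparsity": 0.25,
        "op_types": ["Linear"]
    }]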
The pruner implements the following algorithm:
@@ -756,11 +756,9 @@ Currently, the following head sorting criteria are supported:
* "l2_activation": rank heads by the L2-norm of their attention computation output. * "l2_activation": rank heads by the L2-norm of their attention computation output.
* "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__. * "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__.
We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), which you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be pruned to the same sparsity as specified: the value is interpreted as a global sparsity, and each layer is likely to end up with a different sparsity after pruning by global sort. As a reminder, we found that if global sorting is used, it is usually helpful to use an iterative pruning scheme, interleaving pruning with intermediate finetuning, since global sorting often results in non-uniform sparsity distributions, which makes the model more susceptible to forgetting.
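For example, pruner arguments combining global sorting with iterative pruning might look like the sketch below. The values are illustrative; ``trainer``, ``optimizer``, and ``forward_runner`` are assumed to be defined as in the usage example below, and ``epochs_per_iteration`` is assumed here to be the pruner parameter controlling finetuning between iterations.

.. code-block:: python

    kwargs = {"ranking_criterion": "taylorfo",
              "global_sort": True,            # rank all heads in the model together
              "num_iterations": 4,            # prune gradually instead of one-shot
              "epochs_per_iteration": 1,      # assumed parameter: finetuning epochs between iterations
              "trainer": trainer,
              "optimizer": optimizer,
              "forward_runner": forward_runner}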
In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner's initialization parameters (usage below), or simply pass a dummy input instead and the pruner will run ``torch.jit.trace`` to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also, note that we require the weights belonging to the same layer to have the same sparsity.
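A sketch of the second (experimental) option, where the pruner traces the model to group the weights itself, is shown below. The dummy input, the ``dummy_input`` parameter name, and the ``head_hidden_dim`` value are assumptions for illustration, and ``config_list`` is assumed to be defined as in the usage section below.

.. code-block:: python

    import torch

    # A fake batch of token ids, just to let torch.jit.trace run through the model
    dummy_input = torch.randint(0, 30000, (1, 128))

    kwargs = {"ranking_criterion": "l1_weight",
              "global_sort": False,
              "num_iterations": 1,
              "head_hidden_dim": 64,        # assumed: per-head hidden dimension (64 for BERT base)
              "dummy_input": dummy_input}   # the pruner groups the weights via torch.jit.trace

    pruner = TransformerHeadPruner(model, config_list, **kwargs)
    pruner.compress()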
Usage
^^^^^
@@ -786,6 +784,7 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
["encoder.layer.{}.attention.self.key".format(i) for i in range(12)], ["encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
["encoder.layer.{}.attention.self.value".format(i) for i in range(12)], ["encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)])) ["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))
kwargs = {"ranking_criterion": "l1_weight", kwargs = {"ranking_criterion": "l1_weight",
"global_sort": False, "global_sort": False,
"num_iterations": 1, "num_iterations": 1,
@@ -796,20 +795,26 @@ Suppose we want to prune a BERT with Huggingface implementation, which has the f
"optimizer": optimizer, "optimizer": optimizer,
"forward_runner": forward_runner "forward_runner": forward_runner
} }
config_list = [{
    "sparsity": 0.5,
    "op_types": ["Linear"],
    "op_names": [x for layer in attention_name_groups[:6] for x in layer]      # first six layers
}, {
    "sparsity": 0.25,
    "op_types": ["Linear"],
    "op_names": [x for layer in attention_name_groups[6:] for x in layer]      # last six layers
}]

pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()
In addition to this usage guide, we provide a more detailed example of pruning BERT (Huggingface implementation) for transfer learning on the tasks from the `GLUE benchmark <https://gluebenchmark.com/>`_. Please find it on this :githublink:`page <examples/model_compress/pruning/transformers>`. To run the example, first make sure that you have installed the packages ``transformers`` and ``datasets``. Then, you may start by running the following command:
.. code-block:: bash

   ./run.sh gpu_id glue_task
By default, the code downloads a pretrained BERT language model and finetunes it for several epochs on the downstream GLUE task. Then, the ``TransformerHeadPruner`` is used to prune heads from each layer according to a chosen criterion (by default, the pruner uses magnitude ranking and prunes 50% of the heads in each layer in a one-shot manner). Finally, the pruned model is finetuned on the downstream task for several more epochs. You can check the details of pruning from the logs printed out by the example. You can also experiment with different pruning settings by changing the parameters in ``run.sh``, or by directly changing the ``config_list`` in ``transformer_pruning.py``.
User configuration for Transformer Head Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
......
@@ -3,33 +3,24 @@
import argparse
import logging
import os

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

from nni.compression.pytorch.utils.counter import count_flops_params
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner

import datasets
from datasets import load_dataset, load_metric

import transformers
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler,
)
@@ -38,7 +29,8 @@ logger = logging.getLogger("bert_pruning_example")
def parse_args():
    parser = argparse.ArgumentParser(
        description="Example: prune a Huggingface transformer and finetune on GLUE tasks.")

    parser.add_argument("--model_name", type=str, required=True,
                        help="Pretrained model architecture.")
@@ -53,7 +45,8 @@ def parse_args():
help="Rank the heads globally and prune the heads with lowest scores. If set to False, the " help="Rank the heads globally and prune the heads with lowest scores. If set to False, the "
"heads are only ranked within one layer") "heads are only ranked within one layer")
parser.add_argument("--ranking_criterion", type=str, default="l1_weight", parser.add_argument("--ranking_criterion", type=str, default="l1_weight",
choices=["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"], choices=["l1_weight", "l2_weight",
"l1_activation", "l2_activation", "taylorfo"],
help="Criterion by which the attention heads are ranked.") help="Criterion by which the attention heads are ranked.")
parser.add_argument("--num_iterations", type=int, default=1, parser.add_argument("--num_iterations", type=int, default=1,
help="Number of pruning iterations (1 for one-shot pruning).") help="Number of pruning iterations (1 for one-shot pruning).")
@@ -93,7 +86,8 @@ def get_raw_dataset(task_name):
    """
    raw_dataset = load_dataset("glue", task_name)
    is_regression = task_name == "stsb"
    num_labels = 1 if is_regression else len(
        raw_dataset["train"].features["label"].names)

    return raw_dataset, is_regression, num_labels
@@ -105,29 +99,32 @@ def preprocess(args, tokenizer, raw_dataset):
    assert args.task_name is not None
    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    sentence1_key, sentence2_key = task_to_keys[args.task_name]

    def tokenize(data):
        texts = (
            (data[sentence1_key],) if sentence2_key is None else (
                data[sentence1_key], data[sentence2_key])
        )
        result = tokenizer(*texts, padding=False,
                           max_length=args.max_length, truncation=True)
        if "label" in data:
            result["labels"] = data["label"]
        return result

    processed_datasets = raw_dataset.map(
        tokenize, batched=True, remove_columns=raw_dataset["train"].column_names)
    return processed_datasets
@@ -168,7 +165,8 @@ def train_model(args, model, is_regression, train_dataloader, eval_dataloader, o
        for field in batch.keys():
            batch[field] = batch[field].to(device)
        outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1) if not is_regression \
            else outputs.logits.squeeze()
        metric.add_batch(predictions=predictions, references=batch["labels"])

    eval_metric = metric.compute()
@@ -183,7 +181,7 @@ def trainer_helper(model, train_dataloader, optimizer, device):
    """
    logger.info("Training for 1 epoch...")
    progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True)

    train_epoch = 1
    for epoch in range(train_epoch):
        for step, batch in enumerate(train_dataloader):
@@ -213,7 +211,7 @@ def forward_runner_helper(model, train_dataloader, device):
            _ = model(**batch)
            # note: no loss.backward or optimizer.step() is performed here
            progress_bar.update(1)


def final_eval_for_mnli(args, model, processed_datasets, metric, data_collator):
    """
@@ -248,15 +246,18 @@ def main():
    transformers.utils.logging.set_verbosity_info()

    # Load dataset and tokenizer, and then preprocess the dataset
    raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    processed_datasets = preprocess(args, tokenizer, raw_dataset)
    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name ==
                                      "mnli" else "validation"]

    # Load pretrained model
    config = AutoConfig.from_pretrained(
        args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name, config=config)
    model.to(device)
    #########################################################################
@@ -269,9 +270,10 @@ def main():
    lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps,
                                 num_training_steps=train_steps)
    metric = load_metric("glue", args.task_name)

    logger.info("================= Finetuning before pruning =================")
    train_model(args, model, is_regression, train_dataloader,
                eval_dataloader, optimizer, lr_scheduler, metric, device)

    if args.output_dir is not None:
        torch.save(model.state_dict(), args.output_dir + "/model_before_pruning.pt")
@@ -316,13 +318,11 @@ def main():
        "sparsity": args.sparsity,
        "op_types": ["Linear"],
        "op_names": [x for layer in attention_name_groups[:6] for x in layer]
    }, {
        "sparsity": args.sparsity / 2,
        "op_types": ["Linear"],
        "op_names": [x for layer in attention_name_groups[6:] for x in layer]
    }]

    pruner = TransformerHeadPruner(model, config_list, **kwargs)
    pruner.compress()
@@ -360,7 +360,7 @@ def main():
    # After pruning, finetune again on the target task
    # Get the metric function
    metric = load_metric("glue", args.task_name)

    # re-initialize the optimizer and the scheduler
    optimizer, _, _, data_collator = get_dataloader_and_optimizer(args, tokenizer, model, train_dataset,
                                                                  eval_dataset)
@@ -368,13 +368,16 @@ def main():
                                 num_training_steps=train_steps)

    logger.info("================= Finetuning after Pruning =================")
    train_model(args, model, is_regression, train_dataloader,
                eval_dataloader, optimizer, lr_scheduler, metric, device)

    if args.output_dir is not None:
        torch.save(model.state_dict(), args.output_dir +
                   "/model_after_pruning.pt")

    if args.task_name == "mnli":
        final_eval_for_mnli(args, model, processed_datasets,
                            metric, data_collator)

    flops, params, results = count_flops_params(model, dummy_input)
    print(f"Final model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M")
......