"test/ut/experiment/vscode:/vscode.git/clone" did not exist on "d5857823f73f257a5a54088e118de447b7e9c62f"
Commit 34d706a0 authored by thomwolf's avatar thomwolf
Browse files

pruning in bertology

parent dc8e0019
@@ -724,7 +724,7 @@ We detail them here. This model takes as *inputs*:
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
- - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+ - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
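For illustration, a minimal sketch of building a `head_mask` under the corrected convention (1.0 = keep, 0.0 = mask); the layer/head counts and the call in the comment are assumptions, not part of this diff:

```python
import torch

num_layers, num_heads = 12, 12                 # assumed BERT-base-like dimensions
head_mask = torch.ones(num_layers, num_heads)  # 1.0 everywhere: keep all heads
head_mask[0, 3] = 0.0                          # nullify head 3 of the first layer
head_mask[5, :] = 0.0                          # nullify every head of layer 5

# passed through the forward call, e.g.:
# encoded_layers, pooled_output = model(input_ids, token_type_ids=token_type_ids,
#                                       attention_mask=attention_mask, head_mask=head_mask)
```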
This model *outputs* a tuple composed of:
@@ -859,7 +859,7 @@ We detail them here. This model takes as *inputs*:
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
- - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+ - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
This model *outputs*:
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
@@ -960,7 +960,7 @@ We detail them here. This model takes as *inputs*:
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
- - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+ - `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. It's a mask to be used to nullify some heads of the transformer. 0.0 => head is fully masked, 1.0 => head is not masked.
This model *outputs*:
- `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings) as torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
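Since `past`/`presents` are easy to misread, here is a hedged sketch of incremental decoding with them, assuming an LM-head variant whose forward returns `(lm_logits, presents)`; this is not taken from the diff:

```python
import torch

past = None
generated = input_ids                 # shape [batch, seq_len], assumed given
for _ in range(10):                   # generate 10 tokens greedily
    # after the first step, only the newest token is fed; cached keys/values do the rest
    step_input = generated[:, -1:] if past is not None else generated
    lm_logits, presents = model(step_input, past=past)
    past = presents                   # reuse the cached hidden states next step
    next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=1)
```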
@@ -2,6 +2,7 @@
import os
import argparse
import logging
+from datetime import timedelta, datetime
from tqdm import tqdm
import numpy as np
@@ -35,39 +36,57 @@ def print_2d_tensor(tensor):
for row in range(len(tensor)):
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
-def compute_heads_importance(args, model, eval_dataloader):
+def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
""" Example on how to use model outputs to compute:
- head attention entropy (activated by setting output_attentions=True when we created the model
- head importance scores according to http://arxiv.org/abs/1905.10650
(activated by setting keep_multihead_output=True when we created the model)
"""
# Prepare our tensors
n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
+preds = None
+labels = None
tot_tokens = 0.0
for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
batch = tuple(t.to(args.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
-# Do a forward pass
-all_attentions, logits = model(input_ids, segment_ids, input_mask)
+# Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below)
+all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)
+if compute_entropy:
# Update head attention entropy
for layer, attn in enumerate(all_attentions):
masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
+if compute_importance:
# Update head importance scores with regards to our loss
-# First backpropagate to populate the gradients
-if output_mode == "classification":
+# First, backpropagate to populate the gradients
+if args.output_mode == "classification":
loss_fct = CrossEntropyLoss()
-loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
-elif output_mode == "regression":
+loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
+elif args.output_mode == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), label_ids.view(-1))
loss.backward()
-# Second compute importance scores according to http://arxiv.org/abs/1905.10650
+# Second, compute importance scores according to http://arxiv.org/abs/1905.10650
multihead_outputs = model.bert.get_multihead_outputs()
for layer, mh_layer_output in enumerate(multihead_outputs):
dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
head_importance[layer] += dot.abs().sum(-1).sum(0).detach()
+# Also store our logits/labels if we want to compute metrics afterwards
+if preds is None:
+preds = logits.detach().cpu().numpy()
+labels = label_ids.detach().cpu().numpy()
+else:
+preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+labels = np.append(labels, label_ids.detach().cpu().numpy(), axis=0)
tot_tokens += input_mask.float().detach().sum().data
# Normalize
@@ -76,7 +95,7 @@ def compute_heads_importance(args, model, eval_dataloader):
if args.normalize_importance:
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
-return attn_entropy, head_importance
+return attn_entropy, head_importance, preds, labels
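The importance score follows http://arxiv.org/abs/1905.10650: the loss's sensitivity to a head is approximated by the absolute dot product of the head's output with its gradient, accumulated over the eval set. A self-contained toy reproduction of the einsum reduction above, with invented shapes:

```python
import torch

# batch=2, heads=4, seq_len=8, head_dim=16 -- purely illustrative shapes
mh_output = torch.randn(2, 4, 8, 16, requires_grad=True)
loss = mh_output.pow(2).mean()  # stand-in for the task loss
loss.backward()

# "bhli,bhli->bhl": per-position dot product of each head's output with its gradient;
# the absolute value is then summed over positions and batch -> one score per head
dot = torch.einsum("bhli,bhli->bhl", mh_output.grad, mh_output)
head_importance = dot.abs().sum(-1).sum(0)
print(head_importance)  # shape [4], one importance score per head
```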
def run_model():
parser = argparse.ArgumentParser()
@@ -89,8 +108,11 @@ def run_model():
parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")
parser.add_argument("--try_pruning", action='store_true', help="Whether to try to prune head until a threshold of accuracy.")
parser.add_argument("--pruning_threshold", default=0.9, type=float, help="Pruning threshold of accuracy.")
parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
"(stop masking when metric < threshold * original metric value).")
parser.add_argument("--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step.")
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
@@ -125,9 +147,9 @@ def run_model():
# Prepare GLUE task
task_name = args.task_name.lower()
processor = processors[task_name]()
-output_mode = output_modes[task_name]
label_list = processor.get_labels()
-num_labels = len(label_list)
+args.output_mode = output_modes[task_name]
+args.num_labels = len(label_list)
# Prepare output directory
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
@@ -145,7 +167,7 @@ def run_model():
# keep_multihead_output => will store gradient of attention head outputs for head importance computation
# see: http://arxiv.org/abs/1905.10650
model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
-num_labels=num_labels,
+num_labels=args.num_labels,
output_attentions=True,
keep_multihead_output=True)
if args.local_rank == 0:
@@ -162,7 +184,7 @@ def run_model():
try:
eval_features = torch.load(cached_eval_features_file)
except:
-eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
+eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode)
if args.local_rank in [-1, 0]:
logger.info("Saving eval features to cache file %s", cached_eval_features_file)
torch.save(eval_features, cached_eval_features_file)
@@ -170,7 +192,7 @@ def run_model():
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
-all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if output_mode == "classification" else torch.float)
+all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.data_subset > 0:
@@ -183,16 +205,8 @@ def run_model():
print(args)
torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))
-# To showcase some BERTology methods, we will compute:
-# - the average entropy of each head over the dev set
-# - the importance score of each head over the dev set as explained in http://arxiv.org/abs/1905.10650
-n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
-head_importance = torch.zeros(n_layers, n_heads).to(args.device)
-attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
-tot_tokens = 0.0
# Compute head entropy and importance score
-attn_entropy, head_importance = compute_heads_importance(args, model, eval_dataloader)
+attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)
# Print/save matrices
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy)
@@ -203,14 +217,67 @@ def run_model():
logger.info("Head importance scores")
print_2d_tensor(head_importance)
logger.info("Head ranked by importance scores")
-head_ranks = torch.zeros(n_layers * n_heads, dtype=torch.long, device=args.device)
+head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel())
-print_2d_tensor(head_ranks.view_as(head_importance))
-# Do pruning if we want to
-if args.try_pruning and args.pruning_threshold > 0.0 and args.pruning_threshold < 1.0:
+head_ranks = head_ranks.view_as(head_importance)
+print_2d_tensor(head_ranks)
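The `head_ranks` lines use the inverse-permutation trick: `sort(descending=True)[1]` lists head indices from most to least important, and scattering `arange` through those indices records each head's rank. A toy example:

```python
import torch

scores = torch.tensor([0.3, 0.9, 0.1, 0.5])
order = scores.sort(descending=True)[1]   # tensor([1, 3, 0, 2]): indices by decreasing score
ranks = torch.zeros_like(order)
ranks[order] = torch.arange(len(scores))  # tensor([2, 0, 3, 1]): rank of each head
```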
+# Do masking if we want to
+if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
+_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
+preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
+original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
+logger.info("Pruning: original score: %f", original_score)
+new_head_mask = torch.ones_like(head_importance)
+num_to_mask = int(new_head_mask.numel() * args.masking_amount)
+current_score = original_score
+while current_score >= original_score * args.masking_threshold:
+head_mask = new_head_mask
+# heads from most important to least
+heads_to_mask = head_importance.view(-1).sort(descending=True)[1]
+# keep only not-masked heads
+heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0]
+if len(heads_to_mask) <= num_to_mask:
+break
+# mask heads
+heads_to_mask = heads_to_mask[-num_to_mask:]
+new_head_mask = head_mask.view(-1)
+new_head_mask[heads_to_mask] = 0.0
+new_head_mask = new_head_mask.view_as(head_importance)
+print_2d_tensor(new_head_mask)
+# Compute metric and head importance again
+_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
+preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
+current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
+logger.info("Masking: current score: %f, remaining heads %.1f percent", current_score, head_mask.sum()/head_mask.numel() * 100)
+# Try pruning and test time speedup
+# Pruning is like masking but we actually remove the masked weights
+before_time = datetime.now()
+_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
+compute_entropy=False, compute_importance=False, head_mask=head_mask)
+preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
+score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
+original_time = datetime.now() - before_time
+heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
+assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
+model.bert.prune_heads(heads_to_prune)
+before_time = datetime.now()
+_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
+compute_entropy=False, compute_importance=False, head_mask=None)
+preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
+score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
+new_time = datetime.now() - before_time
+logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
+logger.info("Pruning: speed ratio (original timing / new timing): %f percent", original_time/new_time * 100)
if __name__ == '__main__':
run_model()
@@ -308,7 +308,7 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
-logits = model(input_ids, segment_ids, input_mask)
+logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
@@ -422,7 +422,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
-logits = model(input_ids, segment_ids, input_mask)
+logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
@@ -503,7 +503,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
-logits = model(input_ids, segment_ids, input_mask, labels=None)
+logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
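The call-site changes in this file are worth a note: once `head_mask` joins the forward signatures, positional calls can silently misbind arguments, so keyword arguments are used instead. A hedged before/after sketch:

```python
# before: relies on the positional order (input_ids, token_type_ids, attention_mask);
# any new parameter inserted into the signature would shift the bindings
# logits = model(input_ids, segment_ids, input_mask)

# after: robust to signature growth such as the new head_mask parameter
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
```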
@@ -408,6 +408,8 @@ class BertAttention(nn.Module):
self.output = BertSelfOutput(config)
def prune_heads(self, heads):
+if len(heads) == 0:
+return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
for head in heads:
mask[head] = 0
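Downstream (in code not shown in this hunk), the mask is typically flattened into the indices of the surviving attention dimensions, and the query/key/value projections are rebuilt on just those rows. A simplified sketch of that idea, not the library's exact helper:

```python
import torch
import torch.nn as nn

def prune_linear_rows(layer: nn.Linear, keep_index: torch.LongTensor) -> nn.Linear:
    """Return a smaller Linear that keeps only the output rows in keep_index."""
    new_layer = nn.Linear(layer.in_features, len(keep_index), bias=layer.bias is not None)
    new_layer.weight.data = layer.weight.data[keep_index].clone()
    if layer.bias is not None:
        new_layer.bias.data = layer.bias.data[keep_index].clone()
    return new_layer

# e.g. with 12 heads of size 64, pruning head 3 keeps all rows outside [192, 256)
```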
@@ -858,9 +860,10 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
-# 1.0 in head_mask indicate we mask the head
+# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
-# head_mask has shape num_hidden_layers x batch x n_heads x N x N
+# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
@@ -868,7 +871,6 @@ class BertModel(BertPreTrainedModel):
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
-head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.num_hidden_layers
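A quick shape walk-through of that broadcasting, with hypothetical sizes (the expansion of a 1-D mask across layers happens in a line elided from this hunk):

```python
import torch

num_hidden_layers, num_heads = 12, 12  # hypothetical sizes

head_mask = torch.ones(num_heads)  # one mask shared across all layers
# [num_heads] -> [1, 1, num_heads, 1, 1] -> expanded to every layer
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
print(head_mask.shape)  # torch.Size([12, 1, 12, 1, 1])

# each per-layer slice then broadcasts against attention_probs
# of shape [batch, num_heads, seq_length, seq_length]
```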
@@ -264,6 +264,8 @@ class Attention(nn.Module):
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def prune_heads(self, heads):
+if len(heads) == 0:
+return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
for head in heads:
mask[head] = 0
@@ -714,7 +716,7 @@ class GPT2Model(GPT2PreTrainedModel):
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Prepare head mask if needed
-# 1.0 in head_mask indicate we mask the head
+# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
@@ -724,7 +726,6 @@ class GPT2Model(GPT2PreTrainedModel):
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
-head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.n_layer
@@ -274,6 +274,8 @@ class Attention(nn.Module):
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def prune_heads(self, heads):
+if len(heads) == 0:
+return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
for head in heads:
mask[head] = 0
@@ -710,7 +712,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Prepare head mask if needed
-# 1.0 in head_mask indicate we mask the head
+# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
@@ -720,7 +722,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
-head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.n_layer
@@ -215,9 +215,9 @@ class GPT2ModelTest(unittest.TestCase):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
-head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
-head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
-head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
+head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device)
+head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
+head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
@@ -188,9 +188,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
-head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
-head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
-head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
+head_mask = torch.ones(self.n_layer, self.n_head).to(input_ids.device)
+head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
+head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
@@ -305,9 +305,9 @@ class BertModelTest(unittest.TestCase):
else:
model = model_class(config=config, keep_multihead_output=True)
model.eval()
-head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
-head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
-head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
+head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
+head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
+head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
if isinstance(model, BertModel):
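A hedged sketch of the kind of assertion such a test can make after the convention flip (variable names follow the snippet above; the exact output unpacking is elided in this diff):

```python
# attentions[layer] has shape [batch, n_heads, seq_len, seq_len]; a head whose
# mask entry is 0.0 has its probabilities multiplied by zero, so they vanish
for layer, attn in enumerate(attentions):
    masked_heads = (head_mask[layer] == 0.0).nonzero().view(-1)
    assert attn[:, masked_heads].abs().sum().item() == 0.0
```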