Commit bdb4409e authored by thomwolf's avatar thomwolf Committed by LysandreJik
Browse files

updated pruning logic with sets - Bert and GPT-2

parent 0c8e823b
...@@ -337,26 +337,30 @@ class BertAttention(nn.Module): ...@@ -337,26 +337,30 @@ class BertAttention(nn.Module):
super(BertAttention, self).__init__() super(BertAttention, self).__init__()
self.self = BertSelfAttention(config) self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config) self.output = BertSelfOutput(config)
self.pruned_heads = [] self.pruned_heads = set()
def prune_heads(self, heads): def prune_heads(self, heads):
if len(heads) == 0: if len(heads) == 0:
return return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads: for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads))) # Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0 mask[head] = 0
mask = mask.view(-1).contiguous().eq(1) mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long() index = torch.arange(len(mask))[mask].long()
# Prune linear layers # Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index) self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index) self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index) self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads.extend(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input_tensor, attention_mask, head_mask=None): def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask) self_outputs = self.self(input_tensor, attention_mask, head_mask)
...@@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel): ...@@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert" base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs): def _init_weights(self, module):
super(BertPreTrainedModel, self).__init__(*inputs, **kwargs) """ Initialize the weights """
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
...@@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel): ...@@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config) self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) self.pooler = BertPooler(config)
if hasattr(config, "pruned_heads"): self.init_weights()
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings old_embeddings = self.embeddings.word_embeddings
...@@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel):
self.bert = BertModel(config) self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config) self.cls = BertPreTrainingHeads(config)
self.apply(self.init_weights) self.init_weights()
self.tie_weights() self.tie_weights()
def tie_weights(self): def tie_weights(self):
...@@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert = BertModel(config) self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config) self.cls = BertOnlyMLMHead(config)
self.apply(self.init_weights) self.init_weights()
self.tie_weights() self.tie_weights()
def tie_weights(self): def tie_weights(self):
...@@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): ...@@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.bert = BertModel(config) self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config) self.cls = BertOnlyNSPHead(config)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):
...@@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):
...@@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):
...@@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):
...@@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.bert = BertModel(config) self.bert = BertModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, position_ids=None, head_mask=None): end_positions=None, position_ids=None, head_mask=None):
......
...@@ -233,25 +233,29 @@ class Attention(nn.Module): ...@@ -233,25 +233,29 @@ class Attention(nn.Module):
self.c_proj = Conv1D(n_state, nx) self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop) self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = [] self.pruned_heads = set()
def prune_heads(self, heads): def prune_heads(self, heads):
if len(heads) == 0: if len(heads) == 0:
return return
mask = torch.ones(self.n_head, self.split_size // self.n_head) mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads: for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads))) # Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0 mask[head] = 0
mask = mask.view(-1).contiguous().eq(1) mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long() index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)]) index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers # Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params # Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads) self.n_head = self.n_head - len(heads)
self.pruned_heads.extend(heads) self.pruned_heads = self.pruned_heads.union(heads)
def _attn(self, q, k, v, head_mask=None): def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k) w = torch.matmul(q, k)
...@@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel): ...@@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs) super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
""" """
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
...@@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel):
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
if hasattr(config, "pruned_heads"): self.init_weights()
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.h[int(layer)].attn.n_head == config.n_head:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens) self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
...@@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.transformer = GPT2Model(config) self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.apply(self.init_weights) self.init_weights()
self.tie_weights() self.tie_weights()
def tie_weights(self): def tie_weights(self):
...@@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config) self.multiple_choice_head = SequenceSummary(config)
self.apply(self.init_weights) self.init_weights()
self.tie_weights() self.tie_weights()
def tie_weights(self): def tie_weights(self):
......
...@@ -202,8 +202,7 @@ class PretrainedConfig(object): ...@@ -202,8 +202,7 @@ class PretrainedConfig(object):
config = cls.from_json_file(resolved_config_file) config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'): if hasattr(config, 'pruned_heads'):
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed # Update config with kwargs if needed
to_remove = [] to_remove = []
...@@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module): ...@@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module):
new_embeddings.to(old_embeddings.weight.device) new_embeddings.to(old_embeddings.weight.device)
# initialize all new embeddings (in particular added tokens) # initialize all new embeddings (in particular added tokens)
self.init_weights(new_embeddings) self._init_weights(new_embeddings)
# Copy word embeddings from the previous weights # Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens) num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
...@@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module): ...@@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module):
return model_embeds return model_embeds
def init_weights(self):
""" Initialize and prunes weights if needed. """
# Initialize weights
self.apply(self._init_weights)
# Prune heads if needed
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
def prune_heads(self, heads_to_prune): def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model. """ Prunes heads of the base model.
Arguments: Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
""" """
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
to_be_pruned = {} # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
if int(layer) not in self.config.pruned_heads: union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[int(layer)] = heads self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
to_be_pruned[int(layer)] = heads
else:
for head in heads:
if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head)
if int(layer) in to_be_pruned:
to_be_pruned[int(layer)].append(head)
else:
to_be_pruned[int(layer)] = [head]
else:
logger.warning("Tried to remove head " + str(head) +
" of layer " + str(layer) +
" but it was already removed. The current removed heads are " + str(heads_to_prune))
base_model._prune_heads(to_be_pruned) base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it """ Save a model and its configuration file to a directory, so that it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment