Commit bdb4409e authored by thomwolf's avatar thomwolf Committed by LysandreJik

updated pruning logic with sets - Bert and GPT-2

parent 0c8e823b
@@ -337,26 +337,30 @@ class BertAttention(nn.Module):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = []
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads.extend(heads)
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
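A minimal standalone sketch (not part of this diff; `shifted_head_index` is a hypothetical helper) of the index shift computed in the loop above: a head keeps its original index from the config, but once earlier heads are pruned the remaining heads are packed together, so the effective index moves down by the number of already-pruned heads that precede it.

```python
# Illustrative sketch of the index shift used in prune_heads.

def shifted_head_index(head, pruned_heads):
    """Return the position of `head` among the heads that are still present."""
    return head - sum(1 for h in pruned_heads if h < head)

# With heads 0 and 2 already pruned, original head 3 now lives at index 1:
assert shifted_head_index(3, {0, 2}) == 1
# Head 1 only has head 0 before it, so it moves to index 0:
assert shifted_head_index(1, {0, 2}) == 0
```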
@@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs):
super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
@@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
@@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel):
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
@@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
@@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
position_ids=None, head_mask=None):
@@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
@@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
@@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
@@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.bert = BertModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, position_ids=None, head_mask=None):
......
@@ -233,25 +233,29 @@ class Attention(nn.Module):
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = []
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.pruned_heads.extend(heads)
self.pruned_heads = self.pruned_heads.union(heads)
def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k)
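A small illustration (hypothetical names, assuming the packed q/k/v layout of GPT-2's `c_attn`) of why `index_attn` concatenates three shifted copies of `index`: the single Conv1D produces `[query | key | value]` side by side, each block `split_size` wide, so the columns kept for a surviving head have to be selected in all three blocks.

```python
import torch

n_head, head_dim = 4, 8
split_size = n_head * head_dim           # width of one of the q/k/v blocks

mask = torch.ones(n_head, head_dim)
mask[1] = 0                              # prune head 1
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()

# Repeat the kept columns for the query, key, and value blocks of c_attn.
index_attn = torch.cat([index, index + split_size, index + 2 * split_size])
assert index_attn.numel() == 3 * (n_head - 1) * head_dim
```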
@@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
@@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel):
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.h[int(layer)].attn.n_head == config.n_head:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
@@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
@@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
......
@@ -202,8 +202,7 @@ class PretrainedConfig(object):
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'):
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
to_remove = []
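A short sketch of why the loaded config needs this conversion: JSON serialization turns integer keys into strings and cannot store sets, so the dict comprehension above restores `int` keys and `set` values for the membership tests used during pruning.

```python
import json

# What a config's pruned_heads looks like after a JSON round trip.
serialized = json.dumps({1: [0, 2], 3: [1]})    # keys become strings on dump
loaded = json.loads(serialized)                 # {'1': [0, 2], '3': [1]}

pruned_heads = dict((int(key), set(value)) for key, value in loaded.items())
assert pruned_heads == {1: {0, 2}, 3: {1}}
```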
@@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module):
new_embeddings.to(old_embeddings.weight.device)
# initialize all new embeddings (in particular added tokens)
self.init_weights(new_embeddings)
self._init_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
@@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module):
return model_embeds
def init_weights(self):
""" Initialize and prunes weights if needed. """
# Initialize weights
self.apply(self._init_weights)
# Prune heads if needed
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
to_be_pruned = {}
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
if int(layer) not in self.config.pruned_heads:
self.config.pruned_heads[int(layer)] = heads
to_be_pruned[int(layer)] = heads
else:
for head in heads:
if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head)
if int(layer) in to_be_pruned:
to_be_pruned[int(layer)].append(head)
else:
to_be_pruned[int(layer)] = [head]
else:
logger.warning("Tried to remove head " + str(head) +
" of layer " + str(layer) +
" but it was already removed. The current removed heads are " + str(heads_to_prune))
base_model._prune_heads(to_be_pruned)
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
base_model._prune_heads(heads_to_prune)
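A usage sketch under the new logic (model name and import path assumed for this era of the repository): repeated calls to `prune_heads` are cumulative, the union of pruned heads per layer is stored in the config as lists, and `init_weights()` re-applies the pruning when a saved model is reloaded.

```python
from pytorch_transformers import BertModel   # assumed import path

model = BertModel.from_pretrained('bert-base-uncased')
model.prune_heads({1: [0, 2]})               # prune heads 0 and 2 of layer 1
model.prune_heads({1: [1], 2: [2, 3]})       # later call: the union is stored

print(model.config.pruned_heads)             # e.g. {1: [0, 1, 2], 2: [2, 3]}

model.save_pretrained('./pruned-bert')       # config records the pruned heads
reloaded = BertModel.from_pretrained('./pruned-bert')  # heads re-pruned in init_weights()
```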
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it
......