Commit 11600edc authored by LysandreJik

Rebase on master + DistilBERT head pruning patch

parent b6992b7b
@@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module):
         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
 
+        self.pruned_heads = set()
+
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(self, query, key, value, mask, head_mask = None):
         """
@@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
 
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, nn.Embedding):
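
The rename from init_weights(self, module) to _init_weights(self, module) follows the base-class convention in which an argument-free init_weights() owned by PreTrainedModel applies the per-module _init_weights hook to every submodule, which is why the subclasses below now call self.init_weights() instead of self.apply(self.init_weights). A minimal sketch of that assumed contract, with hypothetical class names:

import torch.nn as nn

class TinyPreTrainedModel(nn.Module):
    # Stand-in for PreTrainedModel: owns the argument-free entry point.
    def init_weights(self):
        self.apply(self._init_weights)  # run the per-module hook everywhere

    def _init_weights(self, module):
        # Per-module hook, overridden by concrete models.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

class TinyModel(TinyPreTrainedModel):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.proj = nn.Linear(4, 4)
        self.init_weights()  # replaces the old self.apply(self.init_weights)

model = TinyModel()
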
@@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings = Embeddings(config)   # Embeddings
         self.transformer = Transformer(config) # Encoder
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         old_embeddings = self.embeddings.word_embeddings
@@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
@@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.classifier = nn.Linear(config.dim, config.num_labels)
         self.dropout = nn.Dropout(config.seq_classif_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
@@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         assert config.num_labels == 2
         self.dropout = nn.Dropout(config.qa_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
......
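
Taken together, the patch lets head pruning be driven from the model level. A hedged usage sketch (not part of the diff; it assumes the transformers package name, the distilbert-base-uncased checkpoint, and that DistilBertModel routes PreTrainedModel.prune_heads through to the patched MultiHeadSelfAttention.prune_heads):

import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# Prune heads 0 and 2 in layer 0 and head 1 in layer 2 ({layer index: [head indices]}).
model.prune_heads({0: [0, 2], 2: [1]})
# A second call is now safe: head 0 is skipped (already pruned) and head 3 is
# remapped relative to the heads that remain in layer 0.
model.prune_heads({0: [0, 3]})

input_ids = torch.tensor([[101, 7592, 1010, 2088, 102]])  # toy token ids
hidden_states = model(input_ids)[0]
print(hidden_states.shape)  # expected: (1, 5, 768)
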