Commit 51261167 authored by Rémi Louf

prune both attention and self-attention heads

parent 17177e73
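Context note: per the diff below, a decoder layer carries two attention blocks, an "attention" module (presumably the encoder-decoder cross-attention) and a "self_attention" module, so head pruning must be applied to both. The old BertDecoderModel code also reached into self.encoder instead of self.decoder, which this commit corrects.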
@@ -633,7 +633,7 @@ class BertModel(BertPreTrainedModel):
         See base class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
+            self.encoder.layer[layer].self_attention.prune_heads(heads)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
@@ -736,7 +736,8 @@ class BertDecoderModel(BertPreTrainedModel):
         See base class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
+            self.decoder.layer[layer].attention.prune_heads(heads)
+            self.decoder.layer[layer].self_attention.prune_heads(heads)
 
     def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
...
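For context, a minimal sketch of how head pruning is invoked from the caller's side. This is not part of the commit: it assumes the upstream transformers package, the public bert-base-uncased checkpoint, and the base-class prune_heads entry point; module names on this branch may differ.

# Hypothetical usage sketch, not from this commit.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# heads_to_prune maps a layer index to the head indices to drop in it,
# matching the dict iterated over in the diff above.
heads_to_prune = {0: [0, 2], 5: [1]}

# The base-class prune_heads walks this mapping and calls each layer's
# pruning hook; after this commit that hook hits the self-attention
# block (and, for the decoder variant, the cross-attention block too).
model.prune_heads(heads_to_prune)

# Pruned heads are removed from the weight matrices, so the forward
# pass is cheaper; the remaining heads keep their trained weights.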