"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "273617b86dbe5cd15afb795e994dffc44e09e2df"
Commit d7092d59 authored by Rémi Louf

rename the attributes in the Bert Layer

Since the preloading of weights relies on the names of the class's
attributes, changing the namespace breaks loading pretrained weights on
Bert and all related models. I reverted `self_attention` to `attention`
and use `crossattention` for the decoder instead.
parent 51261167
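The constraint described in the message comes from how PyTorch names parameters: a submodule's attribute name becomes the prefix of its keys in the state_dict, so renaming the attribute orphans the matching pretrained weights. A minimal sketch of that effect with a toy module (the `TinyLayer` name and Linear submodules are illustrative only, not the transformers code):

import torch.nn as nn

class TinyLayer(nn.Module):
    # The attribute names below, not the class name, determine the state_dict keys.
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(8, 8)       # -> attention.weight, attention.bias
        self.crossattention = nn.Linear(8, 8)  # -> crossattention.weight, crossattention.bias

print(list(TinyLayer().state_dict().keys()))
# ['attention.weight', 'attention.bias', 'crossattention.weight', 'crossattention.bias']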
@@ -321,25 +321,24 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
     def __init__(self, config):
         super(BertLayer, self).__init__()
-        self.self_attention = BertAttention(config)
+        self.attention = BertAttention(config)
         if getattr(config, "is_decoder", False):
-            self.attention = BertAttention(config)
+            self.crossattention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
-        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
-        self_attention_output = self_attention_outputs[0]
-        attention_outputs = self_attention_outputs
+        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_output = attention_outputs[0]
 
         if encoder_hidden_state:
             try:
-                attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
+                crossattention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
             except AttributeError as ae:
                 raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
-            attention_output = attention_outputs[0]
+            crossattention_output = crossattention_outputs[0]
 
-        intermediate_output = self.intermediate(attention_output)
+        intermediate_output = self.intermediate(crossattention_output)
         layer_output = self.output(intermediate_output, attention_output)
         outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs
@@ -633,7 +632,7 @@ class BertModel(BertPreTrainedModel):
         See base class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].self_attention.prune_heads(heads)
+            self.encoder.layer[layer].attention.prune_heads(heads)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
@@ -737,7 +736,7 @@ class BertDecoderModel(BertPreTrainedModel):
         """
         for layer, heads in heads_to_prune.items():
             self.decoder.layer[layer].attention.prune_heads(heads)
-            self.decoder.layer[layer].self_attention.prune_heads(heads)
+            self.decoder.layer[layer].crossattention.prune_heads(heads)
 
     def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
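For context on why the rename had to be reverted, here is roughly how the mismatch surfaces when a checkpoint saved under one attribute name is loaded into a module that uses another. This is a small self-contained sketch with made-up module names, not the actual transformers loading path:

import torch.nn as nn

class SavedLayer(nn.Module):
    # Naming in effect when the checkpoint was written.
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(8, 8)

class RenamedLayer(nn.Module):
    # Same architecture, but the attribute was renamed.
    def __init__(self):
        super().__init__()
        self.self_attention = nn.Linear(8, 8)

state_dict = SavedLayer().state_dict()
result = RenamedLayer().load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['self_attention.weight', 'self_attention.bias']
print(result.unexpected_keys)  # ['attention.weight', 'attention.bias']

With the default strict=True the same mismatch raises a RuntimeError, which is the failure the commit avoids by keeping `attention` as the attribute name.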