Commit 9ca788b2 authored by Rémi Louf

merge the two Bert layer classes

parent edfc8f82
@@ -318,42 +318,30 @@ class BertOutput(nn.Module):
         return hidden_states


-class BertEncoderLayer(nn.Module):
+class BertLayer(nn.Module):
     def __init__(self, config):
-        super(BertEncoderLayer, self).__init__()
-        self.attention = BertAttention(config)
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
-        attention_output = attention_outputs[0]
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
-        return outputs
-
-
-class BertDecoderLayer(nn.Module):
-    def __init__(self, config):
-        super(BertDecoderLayer, self).__init__()
+        super(BertLayer, self).__init__()
         self.self_attention = BertAttention(config)
-        self.attention = BertDecoderAttention(config)
+        if config.get('is_decoder', False):
+            self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

-    def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
         self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
         self_attention_output = self_attention_outputs[0]
-        attention_outputs = self.attention(query=self_attention_output,
-                                           key=encoder_outputs,
-                                           value=encoder_outputs,
-                                           attention_mask=attention_mask,
-                                           head_mask=head_mask)
+        attention_outputs = self_attention_outputs
+        if encoder_hidden_state:
+            try:
+                attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
+            except AttributeError as ae:
+                raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs
@@ -362,7 +350,7 @@ class BertEncoder(nn.Module):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask=None, head_mask=None):
         all_hidden_states = ()
@@ -392,9 +380,10 @@ class BertEncoder(nn.Module):
 class BertDecoder(nn.Module):
     def __init__(self, config):
         super(BertDecoder, self).__init__()
+        config["is_decoder"] = True
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
@@ -403,7 +392,10 @@ class BertDecoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            layer_outputs = layer_module(hidden_states,
+                                         attention_mask=attention_mask,
+                                         head_mask=head_mask[i],
+                                         encoder_hidden_state=encoder_outputs)

             if self.output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
...
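
For context, below is a minimal, self-contained sketch of the dispatch pattern this commit introduces: a single layer class that serves as an encoder layer (self-attention only) or as a decoder layer (self-attention followed by cross-attention over the encoder output). The ToyLayer class, nn.MultiheadAttention, the feed-forward block, and all sizes are illustrative stand-ins, not the repository's BertAttention/BertIntermediate/BertOutput modules.

import torch
from torch import nn

class ToyLayer(nn.Module):
    """Illustrative stand-in for the merged layer: one class covering both
    the encoder case (self-attention only) and the decoder case
    (self-attention followed by cross-attention over the encoder output)."""
    def __init__(self, hidden_size, num_heads, is_decoder=False):
        super(ToyLayer, self).__init__()
        self.is_decoder = is_decoder
        self.self_attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        if is_decoder:
            # Only decoder layers carry the second, cross-attention block.
            self.cross_attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, hidden_states, encoder_hidden_state=None):
        # Self-attention runs in both modes.
        attn_output, _ = self.self_attention(hidden_states, hidden_states, hidden_states)
        if encoder_hidden_state is not None:
            # Decoder mode: attend over the encoder output as keys/values.
            attn_output, _ = self.cross_attention(attn_output,
                                                  encoder_hidden_state,
                                                  encoder_hidden_state)
        return attn_output + self.ffn(attn_output)

encoder_layer = ToyLayer(hidden_size=64, num_heads=4)                   # encoder use
decoder_layer = ToyLayer(hidden_size=64, num_heads=4, is_decoder=True)  # decoder use

src = torch.randn(2, 7, 64)   # (batch, source length, hidden)
tgt = torch.randn(2, 5, 64)   # (batch, target length, hidden)

memory = encoder_layer(src)
out = decoder_layer(tgt, encoder_hidden_state=memory)
print(out.shape)  # torch.Size([2, 5, 64])

Keeping a single class means BertEncoder and BertDecoder can build their stacks from the same module, with the cross-attention block instantiated only when the configuration marks the layer as a decoder.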