Commit a0dcefa3 authored by Rémi Louf

generalize BertSelfAttention to take separate query, key, value

There is currently no way to specify the query, key and value separately
in the Attention module. However, the decoder's "encoder-decoder
attention" layers take the decoder's last output as the query and the
encoder's states as the key and value. We therefore modify the existing
code so that query, key and value can be passed separately.

This strains the current naming conventions: `BertSelfAttention` is no
longer strictly a self-attention module, the way the residual is
forwarded is now awkward, and so on. We will need to do some refactoring
once the decoder is fully implemented.
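
To illustrate the intent, here is a minimal, self-contained sketch. It is not the repository's actual module, just a simplified stand-in that mirrors the generalized `forward(query, key, value)` signature introduced by this commit, to show how one attention module can serve both self-attention and encoder-decoder attention:

```python
import torch
import torch.nn as nn

class GeneralizedAttention(nn.Module):
    """Simplified stand-in for the generalized attention forward()."""
    def __init__(self, hidden_size=16):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value):
        # Project each input separately, as in the patched forward().
        q, k, v = self.query(query), self.key(key), self.value(value)
        scores = torch.matmul(q, k.transpose(-1, -2)) / q.size(-1) ** 0.5
        probs = nn.functional.softmax(scores, dim=-1)
        return torch.matmul(probs, v)

attn = GeneralizedAttention()
decoder_states = torch.randn(2, 5, 16)   # (batch, tgt_len, hidden)
encoder_states = torch.randn(2, 7, 16)   # (batch, src_len, hidden)

# Self-attention: query, key and value all come from the same tensor.
self_attn_out = attn(decoder_states, decoder_states, decoder_states)

# Encoder-decoder attention: the decoder's output is the query,
# the encoder's states provide the key and value.
cross_attn_out = attn(decoder_states, encoder_states, encoder_states)
```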
parent 31adbb24
@@ -198,10 +198,10 @@ class BertSelfAttention(nn.Module):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
+    def forward(self, query, key, value, attention_mask=None, head_mask=None):
+        mixed_query_layer = self.query(query)
+        mixed_key_layer = self.key(key)
+        mixed_value_layer = self.value(value)
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
         key_layer = self.transpose_for_scores(mixed_key_layer)
@@ -279,9 +279,12 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, input_tensor, attention_mask=None, head_mask=None):
-        self_outputs = self.self(input_tensor, attention_mask, head_mask)
-        attention_output = self.output(self_outputs[0], input_tensor)
+    def forward(self, query_tensor, key_tensor, value_tensor, attention_mask=None, head_mask=None):
+        self_outputs = self.self(query_tensor, key_tensor, value_tensor, attention_mask, head_mask)
+        # in encoder-decoder attention we use the output of the previous decoder stage as the query
+        # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput.
+        # This shows the limits of the current code architecture, which may benefit from some refactoring.
+        attention_output = self.output(self_outputs[0], query_tensor)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -323,7 +326,11 @@ class BertEncoderLayer(nn.Module):
         self.output = BertOutput(config)
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_outputs = self.attention(query_tensor=hidden_states,
+                                           key_tensor=hidden_states,
+                                           value_tensor=hidden_states,
+                                           attention_mask=attention_mask,
+                                           head_mask=head_mask)
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
@@ -333,6 +340,7 @@ class BertEncoderLayer(nn.Module):
 class BertDecoderLayer(nn.Module):
     def __init__(self, config):
+        super(BertDecoderLayer, self).__init__()
         raise NotImplementedError
 
     def forward(self, hidden_state, encoder_output):
...