Commit 87d60b6e authored by Rémi Louf

reword explanation of encoder_attention_mask

parent 638fe7f5
@@ -201,9 +201,9 @@ class BertSelfAttention(nn.Module):
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         mixed_query_layer = self.query(hidden_states)
-        # if the attention Module is a encoder-decoder self attention module
-        # they keys & values are given by the encoder; the attention mask
-        # needs to be such that there is no atention on the encoder's padding tokens.
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
         if encoder_hidden_states is not None:
             mixed_key_layer = self.key(encoder_hidden_states)
             mixed_value_layer = self.value(encoder_hidden_states)
...
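The reworded comment describes cross-attention: queries are projected from the decoder's hidden states while keys and values are projected from the encoder's hidden states, and the encoder's padding positions must receive no attention. The following is a minimal single-head sketch of that idea, not the transformers implementation; the helper name cross_attention and the explicit weight matrices are illustrative only, and encoder_attention_mask follows the usual 1-for-token / 0-for-padding convention.

import torch
import torch.nn.functional as F

def cross_attention(decoder_hidden, encoder_hidden, encoder_attention_mask,
                    w_q, w_k, w_v):
    # Queries come from the decoder; keys and values come from the encoder.
    q = decoder_hidden @ w_q                     # (batch, tgt_len, d)
    k = encoder_hidden @ w_k                     # (batch, src_len, d)
    v = encoder_hidden @ w_v                     # (batch, src_len, d)

    # Scaled dot-product scores over encoder positions.
    scores = q @ k.transpose(-1, -2) / (q.size(-1) ** 0.5)   # (batch, tgt_len, src_len)

    # Turn the padding mask into an additive bias: padded encoder tokens get a
    # large negative score, so their post-softmax attention weight is ~0.
    additive_mask = (1.0 - encoder_attention_mask[:, None, :].float()) * -1e9
    scores = scores + additive_mask

    weights = F.softmax(scores, dim=-1)
    return weights @ v                           # (batch, tgt_len, d)

# Example usage with random tensors (d = 8, two encoder padding positions):
d = 8
dec = torch.randn(1, 3, d)
enc = torch.randn(1, 5, d)
enc_mask = torch.tensor([[1, 1, 1, 0, 0]])
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
out = cross_attention(dec, enc, enc_mask, w_q, w_k, w_v)   # (1, 3, 8)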