"docs/source/vscode:/vscode.git/clone" did not exist on "66add161dcaa91f4e60c8e3224ed297ab72e7b0f"
Commit 28d0ba35 authored by Rémi Louf

only init encoder_attention_mask if stack is decoder

We currently initialize `encoder_attention_mask` whenever it is `None`,
regardless of whether the stack is an encoder or a decoder. Since this
may lead to bugs that are difficult to track down, I added a condition
that checks whether the current stack is a decoder.
parent 1c542df7
@@ -656,7 +656,7 @@ class BertModel(BertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
-        if encoder_attention_mask is None:
+        if self.config.is_decoder and encoder_attention_mask is None:
             encoder_attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
...
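For illustration, here is a minimal, self-contained sketch of the mask-initialization logic after this change. The helper name `init_masks` and the bare `config` object are hypothetical stand-ins for the surrounding `BertModel.forward` code; only the `config.is_decoder` check mirrors the actual diff above.

```python
import torch
from types import SimpleNamespace


def init_masks(config, input_shape, device,
               attention_mask=None, encoder_attention_mask=None):
    # Self-attention mask defaults to all ones (attend to every position).
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, device=device)
    # After this commit, the cross-attention mask is only created when the
    # stack is a decoder; encoder-only stacks leave it as None.
    if config.is_decoder and encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(input_shape, device=device)
    return attention_mask, encoder_attention_mask


# Encoder-only stack: no cross-attention mask is allocated.
attn, enc_attn = init_masks(SimpleNamespace(is_decoder=False), (2, 8),
                            torch.device("cpu"))
assert enc_attn is None

# Decoder stack: the cross-attention mask is filled in as before.
attn, enc_attn = init_masks(SimpleNamespace(is_decoder=True), (2, 8),
                            torch.device("cpu"))
assert enc_attn is not None
```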