"vscode:/vscode.git/clone" did not exist on "69c85d976ad1cab364f904e0b8de2885d17267ab"
Commit 06a6cb6f authored by Tom Hosking

Refactor BartModel so that input checks are handled within BartEncoder and BartDecoder

parent 30624f70
File added
@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
             - **all_attentions** (List[Tensor]): Attention weights for each layer.
               During training might not be of length n_layers because of layer dropout.
         """
+        # check attention mask and invert
+        if attention_mask is not None:
+            assert attention_mask.dim() == 2
+            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
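For reference, a minimal standalone sketch (not part of the commit) of what the inversion above computes: a 0/1 padding mask, where 1 marks a real token, becomes an additive mask of 0.0 / -10000.0 that is summed with the raw attention scores before the softmax, so padded positions receive (almost) zero attention weight.

    import torch
    import torch.nn.functional as F

    # 1 = real token, 0 = padding (Hugging Face convention)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

    # same transformation as in the diff: 1 -> 0.0, 0 -> -10000.0
    additive_mask = (1.0 - attention_mask.long()) * -10000.0

    # adding the mask to (here, dummy) attention scores before the
    # softmax drives the weight on padded positions to ~0
    scores = torch.zeros(1, 5)
    probs = F.softmax(scores + additive_mask, dim=-1)
    print(probs)  # real tokens share ~1/3 each; padded positions get ~0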
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
             - hidden states
             - attentions
         """
+        # check attention mask and invert
+        if encoder_padding_mask is not None:
+            assert encoder_padding_mask.dim() == 2
+            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
+            assert encoder_padding_mask.max() <= 0
         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
@@ -823,11 +836,6 @@ class BartModel(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
     ):
-        if attention_mask is not None:
-            assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
-            assert attention_mask.max() <= 0
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
...
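The practical effect of moving the check out of BartModel.forward is that the full model and the standalone encoder/decoder now all accept a plain 0/1 attention mask and invert it internally. A hypothetical usage sketch, assuming the BartModel import path and the facebook/bart-large checkpoint name (both outside this diff):

    import torch
    from transformers import BartModel

    model = BartModel.from_pretrained("facebook/bart-large")
    input_ids = torch.tensor([[0, 31414, 232, 2, 1]])  # 1 is BART's <pad> id
    attention_mask = torch.tensor([[1, 1, 1, 1, 0]])   # plain 0/1 mask
    outputs = model(input_ids, attention_mask=attention_mask)  # inversion happens inside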