"docs/vscode:/vscode.git/clone" did not exist on "7748cbbe7d85c85e0328a5bc92a70979ef987060"
Unverified commit 6ffe03a0, authored by Thomas Wolf, committed by GitHub

Merge pull request #3137 from tomhosking/bart-refactor

Refactor BartModel so that input checks are handled within enc/dec
parents 3e5da38d 31acb8dc
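The check being moved converts the usual 1/0 padding mask into an additive mask: 0.0 for positions the model should attend to and -10000.0 (effectively minus infinity after the softmax) for padding. A minimal standalone sketch of that transformation; `invert_mask` is a hypothetical helper name, not the library's own function:

```python
import torch

def invert_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Turn a (batch, seq_len) mask of 1s (keep) and 0s (pad) into an additive mask."""
    assert attention_mask.dim() == 2
    # 1 -> 0.0 (score unchanged), 0 -> -10000.0 (drowned out by the softmax)
    inverted = (1.0 - attention_mask.long()) * -10000.0
    assert inverted.max() <= 0
    return inverted

mask = torch.tensor([[1, 1, 1, 0, 0]])
print(invert_mask(mask))  # zeros for the real tokens, -10000.0 for the two padded positions
```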
@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
- **all_attentions** (List[Tensor]): Attention weights for each layer.
  During training, it might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
assert attention_mask.dim() == 2
attention_mask = (1.0 - attention_mask.long()) * -10000.0
assert attention_mask.max() <= 0
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
- hidden states
- attentions
"""
# check attention mask and invert
if encoder_padding_mask is not None:
assert encoder_padding_mask.dim() == 2
encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
assert encoder_padding_mask.max() <= 0
# embed positions
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
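In the decoder, the inverted `encoder_padding_mask` only affects cross-attention over the encoder outputs; self-attention over previously generated tokens additionally needs a causal mask so a position cannot look ahead. A hand-rolled illustration of such a causal mask (not the exact buffer the library builds):

```python
import torch

tgt_len = 4
# position i may attend only to positions <= i; future positions get -10000.0
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), -10000.0), diagonal=1)
print(causal_mask)
```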
@@ -808,11 +821,6 @@ class BartModel(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
):
if attention_mask is not None:
assert attention_mask.dim() == 2
attention_mask = (1.0 - attention_mask.long()) * -10000.0
assert attention_mask.max() <= 0
# make masks if user doesn't supply
if not self.decoder.generation_mode:
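The net effect of the refactor is that `BartModel.forward` no longer performs the inversion itself: callers keep passing the plain 1/0 `attention_mask`, and the encoder/decoder convert it internally. A rough usage sketch; the checkpoint name and tokenizer call are assumptions and may differ across transformers versions:

```python
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartModel.from_pretrained("facebook/bart-large")

batch = tokenizer(["Hello world!", "Hi"], return_tensors="pt", padding=True)
# attention_mask stays a plain 1/0 padding mask; the -10000.0 inversion
# now happens inside BartEncoder/BartDecoder rather than in BartModel.forward.
outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
```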