"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e71f32c0ef326334b2f36e79a2ffa2e1938c64ff"
Unverified commit 6ffe03a0, authored by Thomas Wolf, committed by GitHub

Merge pull request #3137 from tomhosking/bart-refactor

Refactor BartModel so that input checks are handled within enc/dec
parents 3e5da38d 31acb8dc
@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
             - **all_attentions** (List[Tensor]): Attention weights for each layer.
               During training might not be of length n_layers because of layer dropout.
         """
+        # check attention mask and invert
+        if attention_mask is not None:
+            assert attention_mask.dim() == 2
+            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
             - hidden states
             - attentions
         """
+        # check attention mask and invert
+        if encoder_padding_mask is not None:
+            assert encoder_padding_mask.dim() == 2
+            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
+            assert encoder_padding_mask.max() <= 0
         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
@@ -808,11 +821,6 @@ class BartModel(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
     ):
-        if attention_mask is not None:
-            assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
-            assert attention_mask.max() <= 0
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
...
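For context on what the relocated check does: it converts a standard 0/1 padding mask (1 = real token, 0 = padding) into an additive mask that is summed with the raw attention scores before the softmax, so padded positions receive a score of roughly -10000 and effectively zero attention weight. Moving the check into `BartEncoder`/`BartDecoder` means callers that use those modules directly get the same validation as `BartModel.forward`. The snippet below is a minimal standalone sketch of that inversion pattern; the tensor names and shapes are illustrative, not taken from the library.

```python
import torch

# Illustrative 0/1 padding mask: 1 = real token, 0 = padding (batch of 2, seq len 4).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Same inversion as in the diff: padded positions become -10000.0, real ones 0.0.
assert attention_mask.dim() == 2
inverted_mask = (1.0 - attention_mask.long()) * -10000.0
assert inverted_mask.max() <= 0

# Hypothetical raw attention scores for one head: (batch, tgt_len, src_len).
scores = torch.randn(2, 4, 4)

# Adding the inverted mask before the softmax drives padded columns toward zero weight.
weights = torch.softmax(scores + inverted_mask.unsqueeze(1), dim=-1)
print(weights[0, 0])  # last column ~0 for the first example
print(weights[1, 0])  # last two columns ~0 for the second example
```

Keeping the mask additive (rather than masking by boolean indexing) lets it broadcast over heads and query positions without any change to the attention computation itself.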