Unverified Commit 3c6bf899 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

modeling_bart: 3 small cleanups that dont change outputs (#7381)

* Mbart passing

* boom boom

* cleaner assert

* add assert

* Fix tests
parent 9e68d075
...@@ -307,7 +307,7 @@ class BartEncoder(nn.Module): ...@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm # mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward( def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
...@@ -550,6 +550,10 @@ class BartDecoder(nn.Module): ...@@ -550,6 +550,10 @@ class BartDecoder(nn.Module):
positions = self.embed_positions(input_ids, use_cache=use_cache) positions = self.embed_positions(input_ids, use_cache=use_cache)
if use_cache: if use_cache:
if input_ids.shape[1] != 1 or past_key_values is None:
# if you make this an AssertionError, test_benchmark breaks.
warnings.warn("pass decoder_past_key_value_states to use_cache")
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any() # assert input_ids.ne(self.padding_idx).any()
...@@ -590,11 +594,12 @@ class BartDecoder(nn.Module): ...@@ -590,11 +594,12 @@ class BartDecoder(nn.Module):
if use_cache: if use_cache:
next_decoder_cache.append(layer_past.copy()) next_decoder_cache.append(layer_past.copy())
if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
if output_attentions: if output_attentions:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if self.layer_norm: # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
if output_hidden_states: if output_hidden_states:
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states) all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
......
...@@ -86,8 +86,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): ...@@ -86,8 +86,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device) batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
translated_tokens = self.model.generate(**batch) translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0]) assert self.tgt_text == decoded
self.assertEqual(self.tgt_text[1], decoded[1])
def test_mbart_enro_config(self): def test_mbart_enro_config(self):
mbart_models = ["facebook/mbart-large-en-ro"] mbart_models = ["facebook/mbart-large-en-ro"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment