Commit 374deef4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fixed typo

parent a2c8e516
...@@ -855,7 +855,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -855,7 +855,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
cur_len = 1 cur_len = 1
# put model in generation mode if it has one # put model in generation mode if it has one
if hasattr(self.model, "generation_mode"): if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
self.model.decoder.generation_mode = True self.model.decoder.generation_mode = True
else: else:
encoder_inputs = None encoder_inputs = None
......
...@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase): ...@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
new_input_ids = lm_model.generate( new_input_ids = lm_model.generate(
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
) )
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1)) self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
# TODO(SS): uneven length batches, empty inputs # TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self): def test_shift_tokens_right(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment