Commit f1c71da1 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix eos_token_ids in test

parent 6047f46b
......@@ -61,7 +61,7 @@ class ModelTester:
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20
self.eos_token_id = 2
self.eos_token_ids = [2]
self.pad_token_id = 1
self.bos_token_id = 0
torch.manual_seed(0)
......@@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase):
num_beams=4,
max_length=extra_len + 2,
do_sample=False,
decoder_start_token_id=hf.config.eos_token_id,
decoder_start_token_id=hf.config.eos_token_ids[0],
) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
......@@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
decoder_start_token_id=hf.config.eos_token_id,
decoder_start_token_id=hf.config.eos_token_ids[0],
)
decoded = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment