Unverified Commit 2544c143 authored by Patrick von Platen, committed by GitHub

[Generate Tests] Make sure no tokens are force-generated (#18053)

parent 91c4a3ab
@@ -116,6 +116,12 @@ class BartModelTester:
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -145,6 +151,8 @@ class BartModelTester:
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def get_pipeline_config(self):
...
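The comment added in each tester is the crux of the change: forcing a token is implemented by masking every other logit to -inf, so if the forced token's own logit has already been masked to -inf, the whole score row becomes -inf and softmax yields `nan`, which breaks generation. Below is a minimal, self-contained sketch of that failure mode in plain PyTorch; it is not code from this commit, and the toy `scores` tensor and `forced_token_id` are made up purely for illustration.

```python
import torch

# Toy next-token scores for a 3-token vocabulary; the token we want to force
# (index 1) has already been masked to -inf by some earlier logits processor.
scores = torch.tensor([[1.0, float("-inf"), 2.0]])
forced_token_id = 1

# Emulate what forcing a token does: every logit except the forced one is set
# to -inf, while the forced token keeps its (already -inf) score.
forced = torch.full_like(scores, float("-inf"))
forced[:, forced_token_id] = scores[:, forced_token_id]

# The whole row is now -inf, so the probabilities are 0/0 = nan.
print(torch.softmax(forced, dim=-1))  # tensor([[nan, nan, nan]])
```

Setting `self.forced_bos_token_id` and `self.forced_eos_token_id` to `None` in the testers, and passing them through to the model configs, keeps the generate tests away from this interaction.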
@@ -107,6 +107,12 @@ class BlenderbotModelTester:
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
@@ -135,6 +141,8 @@ class BlenderbotModelTester:
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def get_pipeline_config(self):
...
@@ -107,6 +107,12 @@ class BlenderbotSmallModelTester:
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
@@ -135,6 +141,8 @@ class BlenderbotSmallModelTester:
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def prepare_config_and_inputs_for_common(self):
...
@@ -123,6 +123,12 @@ class MarianModelTester:
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
@@ -152,6 +158,8 @@ class MarianModelTester:
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def prepare_config_and_inputs_for_common(self):
...
@@ -113,6 +113,12 @@ class MBartModelTester:
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -142,6 +148,8 @@ class MBartModelTester:
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def prepare_config_and_inputs_for_common(self):
...
@@ -104,6 +104,12 @@ class PegasusModelTester:
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # forcing a certain token to be generated sets all other tokens to -inf;
        # if, however, the token to be generated is already at -inf, this can lead to
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -151,6 +157,8 @@ class PegasusModelTester:
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )
    def prepare_config_and_inputs_for_common(self):
...