"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "801aaa55087d80a5d736b619bce0344af25e9883"
Unverified Commit e8f44af5 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[generate] do_sample default back to False (#3298)

* change do_samples back

* None better default as boolean

* adapt do_sample to True in test example

* make style
parent 2187c49f
...@@ -35,7 +35,6 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): ...@@ -35,7 +35,6 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
min_length=min_length + 1, # +1 from original because we start at step=1 min_length=min_length + 1, # +1 from original because we start at step=1
no_repeat_ngram_size=3, no_repeat_ngram_size=3,
early_stopping=True, early_stopping=True,
do_sample=False,
decoder_start_token_id=model.config.eos_token_ids[0], decoder_start_token_id=model.config.eos_token_ids[0],
) )
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
......
...@@ -460,8 +460,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -460,8 +460,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids=None, input_ids=None,
max_length=None, max_length=None,
min_length=None, min_length=None,
do_sample=True, do_sample=None,
early_stopping=False, early_stopping=None,
num_beams=None, num_beams=None,
temperature=None, temperature=None,
top_k=None, top_k=None,
...@@ -494,7 +494,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -494,7 +494,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
The max length of the sequence to be generated. Between 1 and infinity. Default to 20. The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
do_sample: (`optional`) bool do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`. If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.
num_beams: (`optional`) int num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
......
...@@ -656,8 +656,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -656,8 +656,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids=None, input_ids=None,
max_length=None, max_length=None,
min_length=None, min_length=None,
do_sample=True, do_sample=None,
early_stopping=False, early_stopping=None,
num_beams=None, num_beams=None,
temperature=None, temperature=None,
top_k=None, top_k=None,
...@@ -691,7 +691,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -691,7 +691,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
The max length of the sequence to be generated. Between 1 and infinity. Default to 20. The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
do_sample: (`optional`) bool do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`. If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.
num_beams: (`optional`) int num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
......
...@@ -285,7 +285,12 @@ class BartHeadTests(unittest.TestCase): ...@@ -285,7 +285,12 @@ class BartHeadTests(unittest.TestCase):
max_length = 5 max_length = 5
new_input_ids = lm_model.generate( new_input_ids = lm_model.generate(
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length input_ids.clone(),
do_sample=True,
num_return_sequences=1,
num_beams=2,
no_repeat_ngram_size=3,
max_length=max_length,
) )
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1)) self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
# TODO(SS): uneven length batches, empty inputs # TODO(SS): uneven length batches, empty inputs
......
...@@ -638,16 +638,16 @@ class ModelTesterMixin: ...@@ -638,16 +638,16 @@ class ModelTesterMixin:
if config.bos_token_id is None: if config.bos_token_id is None:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
model.generate(max_length=5) model.generate(do_sample=True, max_length=5)
# batch_size = 1 # batch_size = 1
self._check_generated_tokens(model.generate(input_ids)) self._check_generated_tokens(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1 # batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(input_ids, num_beams=3)) self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
else: else:
# batch_size = 1 # batch_size = 1
self._check_generated_tokens(model.generate(max_length=5)) self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1 # batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(max_length=5, num_beams=3)) self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation # generating multiple sequences when greedy no beam generation
...@@ -659,12 +659,14 @@ class ModelTesterMixin: ...@@ -659,12 +659,14 @@ class ModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2) model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample # batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3)) self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy # batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False)) self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample # batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,)) self._check_generated_tokens(
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
)
# batch_size > 1, num_beams > 1, greedy # batch_size > 1, num_beams > 1, greedy
self._check_generated_tokens( self._check_generated_tokens(
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3) model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
......
...@@ -422,16 +422,16 @@ class TFModelTesterMixin: ...@@ -422,16 +422,16 @@ class TFModelTesterMixin:
if config.bos_token_id is None: if config.bos_token_id is None:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
model.generate(max_length=5) model.generate(do_sample=True, max_length=5)
# batch_size = 1 # batch_size = 1
self._check_generated_tokens(model.generate(input_ids)) self._check_generated_tokens(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1 # batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(input_ids, num_beams=3)) self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
else: else:
# batch_size = 1 # batch_size = 1
self._check_generated_tokens(model.generate(max_length=5)) self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1 # batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(max_length=5, num_beams=3)) self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation # generating multiple sequences when greedy no beam generation
...@@ -443,12 +443,14 @@ class TFModelTesterMixin: ...@@ -443,12 +443,14 @@ class TFModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2) model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample # batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3)) self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy # batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False)) self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample # batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,)) self._check_generated_tokens(
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
)
# batch_size > 1, num_beams > 1, greedy # batch_size > 1, num_beams > 1, greedy
self._check_generated_tokens( self._check_generated_tokens(
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3) model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment