Unverified Commit 59d5edef authored by Yih-Dar, committed by GitHub
Browse files

Avoid flaky generation sampling tests (#21445)



* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 31c351c4
...@@ -780,7 +780,7 @@ class GenerationTesterMixin: ...@@ -780,7 +780,7 @@ class GenerationTesterMixin:
forced_eos_token_id=model.config.forced_eos_token_id, forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length, max_length=max_length,
) )
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal # check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate( output_sample, output_generate = self._sample_generate(
......
...@@ -621,7 +621,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt ...@@ -621,7 +621,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt
config.forced_eos_token_id = None config.forced_eos_token_id = None
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
num_return_sequences = 2 num_return_sequences = 2
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
...@@ -670,7 +670,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt ...@@ -670,7 +670,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt
config.eos_token_id = None config.eos_token_id = None
config.forced_eos_token_id = None config.forced_eos_token_id = None
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment