Unverified Commit 8fb7c908 authored by Daniel Suess's avatar Daniel Suess Committed by GitHub
Browse files

Fix failing tests for XLA generation in TF (#18298)

* Fix failing test_xla_generate_slow tests

* Fix failing speech-to-text xla_generate tests
parent a507908c
@@ -1685,6 +1685,17 @@ class TFModelTesterMixin:
config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config)
if model.supports_xla_generation:
@@ -1714,15 +1725,6 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
# TODO (Joao): find the issues related to the following models. They are passing the fast test, but failing
# the slow one.
if any(
[
model in str(self).lower()
for model in ["tfbart", "tfblenderbot", "tfmarian", "tfmbart", "tfopt", "tfpegasus"]
]
):
return
num_beams = 8
num_return_sequences = 2
max_length = 128
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment