"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b"
Unverified Commit e1eb3efd authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Temporarily increase tol for PT-FLAX whisper tests (#23288)

parent b3bbe1bd
...@@ -248,6 +248,10 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -248,6 +248,10 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
for jitted_output, output in zip(jitted_outputs, outputs): for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape) self.assertEqual(jitted_output.shape, output.shape)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# We override with a slightly higher tol value, as test recently became flaky
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
# overwrite because of `input_features` # overwrite because of `input_features`
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self): def test_save_load_bf16_to_base_pt(self):
......
...@@ -828,6 +828,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -828,6 +828,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# We override with a slightly higher tol value, as test recently became flaky # We override with a slightly higher tol value, as test recently became flaky
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# We override with a slightly higher tol value, as test recently became flaky
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self): def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment