Unverified Commit 2de5cb12 authored by Matt's avatar Matt Committed by GitHub
Browse files

Use the Keras set_random_seed in tests (#30504)

Use the Keras set_random_seed to ensure reproducible weight initialization
parent 20081c74
...@@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase): ...@@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase):
@slow @slow
@require_tf @require_tf
def test_load_default_pipelines_tf(self): def test_load_default_pipelines_tf(self):
import tensorflow as tf from transformers.modeling_tf_utils import keras
from transformers.pipelines import SUPPORTED_TASKS from transformers.pipelines import SUPPORTED_TASKS
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731 set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731
for task in SUPPORTED_TASKS.keys(): for task in SUPPORTED_TASKS.keys():
if task == "table-question-answering": if task == "table-question-answering":
# test table in seperate test due to more dependencies # test table in seperate test due to more dependencies
...@@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase): ...@@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase):
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf) self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
# clean-up as much as possible GPU memory occupied by PyTorch # clean-up as much as possible GPU memory occupied by TF
gc.collect() gc.collect()
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment