Unverified Commit 2de5cb12 authored by Matt's avatar Matt Committed by GitHub
Browse files

Use the Keras set_random_seed in tests (#30504)

Use the Keras set_random_seed to ensure reproducible weight initialization
parent 20081c74
......@@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase):
@slow
@require_tf
def test_load_default_pipelines_tf(self):
import tensorflow as tf
from transformers.modeling_tf_utils import keras
from transformers.pipelines import SUPPORTED_TASKS
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731
for task in SUPPORTED_TASKS.keys():
if task == "table-question-answering":
# test table in seperate test due to more dependencies
......@@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase):
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by TF
gc.collect()
@slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment