Unverified Commit 18abc756 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`core`] Import tensorflow inside relevant methods in `trainer_utils` (#26106)

import tensorflow inside relevant methods in trainer_utils
parent 9cccb3a8
...@@ -46,9 +46,6 @@ from .utils import ( ...@@ -46,9 +46,6 @@ from .utils import (
if is_torch_available(): if is_torch_available():
import torch import torch
if is_tf_available():
import tensorflow as tf
def seed_worker(_): def seed_worker(_):
""" """
...@@ -80,6 +77,8 @@ def enable_full_determinism(seed: int, warn_only: bool = False): ...@@ -80,6 +77,8 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
if is_tf_available(): if is_tf_available():
import tensorflow as tf
tf.config.experimental.enable_op_determinism() tf.config.experimental.enable_op_determinism()
...@@ -101,6 +100,8 @@ def set_seed(seed: int): ...@@ -101,6 +100,8 @@ def set_seed(seed: int):
if is_torch_xpu_available(): if is_torch_xpu_available():
torch.xpu.manual_seed_all(seed) torch.xpu.manual_seed_all(seed)
if is_tf_available(): if is_tf_available():
import tensorflow as tf
tf.random.set_seed(seed) tf.random.set_seed(seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment