Unverified Commit 10d232e8 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Add deterministic config to `set_seed` (#29778)

* Add deterministic config

* Add note on slowdown

* English fails me again
parent f0bfb150
......@@ -82,12 +82,15 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
tf.config.experimental.enable_op_determinism()
def set_seed(seed: int):
def set_seed(seed: int, deterministic: bool = False):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
Args:
seed (`int`): The seed to set.
seed (`int`):
The seed to set.
deterministic (`bool`, *optional*, defaults to `False`):
Whether to use deterministic algorithms where available. Can slow down training.
"""
random.seed(seed)
np.random.seed(seed)
......@@ -95,6 +98,8 @@ def set_seed(seed: int):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if deterministic:
torch.use_deterministic_algorithms(True)
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
......@@ -103,6 +108,8 @@ def set_seed(seed: int):
import tensorflow as tf
tf.random.set_seed(seed)
if deterministic:
tf.config.experimental.enable_op_determinism()
def neftune_post_forward_hook(module, input, output):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment