"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "952b9131a21b03691c5086b0f32f11d927664755"
Unverified Commit 10d232e8 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Add deterministic config to `set_seed` (#29778)

* Add deterministic config

* Add note on slowdown

* English fails me again
parent f0bfb150
...@@ -82,12 +82,15 @@ def enable_full_determinism(seed: int, warn_only: bool = False): ...@@ -82,12 +82,15 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
tf.config.experimental.enable_op_determinism() tf.config.experimental.enable_op_determinism()
def set_seed(seed: int): def set_seed(seed: int, deterministic: bool = False):
""" """
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed). Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
Args: Args:
seed (`int`): The seed to set. seed (`int`):
The seed to set.
deterministic (`bool`, *optional*, defaults to `False`):
Whether to use deterministic algorithms where available. Can slow down training.
""" """
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -95,6 +98,8 @@ def set_seed(seed: int): ...@@ -95,6 +98,8 @@ def set_seed(seed: int):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available # ^^ safe to call this function even if cuda is not available
if deterministic:
torch.use_deterministic_algorithms(True)
if is_torch_npu_available(): if is_torch_npu_available():
torch.npu.manual_seed_all(seed) torch.npu.manual_seed_all(seed)
if is_torch_xpu_available(): if is_torch_xpu_available():
...@@ -103,6 +108,8 @@ def set_seed(seed: int): ...@@ -103,6 +108,8 @@ def set_seed(seed: int):
import tensorflow as tf import tensorflow as tf
tf.random.set_seed(seed) tf.random.set_seed(seed)
if deterministic:
tf.config.experimental.enable_op_determinism()
def neftune_post_forward_hook(module, input, output): def neftune_post_forward_hook(module, input, output):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment