Unverified commit 8678879f, authored by 조준래 and committed by GitHub
Browse files

fix: default value reflects the runtime environment variables rather than the...

fix: default value reflects the runtime environment variables rather than the ones present at import time. (#32153)

* fix: default value reflects the runtime environment variables rather than the ones present at import time.

* Fix: Change `deterministic` to None by default; use env var if None
parent 01be5b48
......@@ -193,7 +193,7 @@ def _flash_attention_forward(
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
- deterministic: bool = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1",
+ deterministic: bool = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
......@@ -233,6 +233,8 @@ def _flash_attention_forward(
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"):
+ if deterministic is None:
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = deterministic
if softcap is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment