"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "48be2ee2531e5cdd8d9b2fe5c825cdaf95601908"
Unverified Commit 036d3de2 authored by 조준래, committed by GitHub
Browse files

add flash-attn deterministic option to flash-attn>=2.4.1 (#31961)



* add flash-attn deterministic option to flash-attn>=2.4.1

* Add Missing Import

* Fix ruff linting issues

* Replace `is_flash_attn_greater_or_equal_2_41` with the existing `is_flash_attn_greater_or_equal`

---------
Co-authored-by: jun.4 <jun.4@kakaobrain.com>
parent 89eec5cf
......@@ -14,12 +14,13 @@
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from .utils import is_flash_attn_2_available
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
if is_flash_attn_2_available():
......@@ -141,6 +142,7 @@ def _flash_attention_forward(
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1",
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
......@@ -164,6 +166,8 @@ def _flash_attention_forward(
flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
softcap (`float`, *optional*):
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
"""
if not use_top_left_mask:
causal = is_causal
......@@ -177,6 +181,9 @@ def _flash_attention_forward(
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"):
flash_kwargs["deterministic"] = deterministic
if softcap is not None:
flash_kwargs["softcap"] = softcap
......
......@@ -74,6 +74,8 @@ def enable_full_determinism(seed: int, warn_only: bool = False):
# The environment variable required to enable deterministic mode on Ascend NPUs.
os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
os.environ["HCCL_DETERMINISTIC"] = "1"
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
torch.use_deterministic_algorithms(True, warn_only=warn_only)
# Enable CUDNN deterministic mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment