Unverified Commit 7804d116 authored by Kirthi Shankar Sivamani, committed by GitHub

Disable FAv2 for deterministic use (#366)

* Disable FAv2 for deterministic use
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Also disable FusedAttention backend with deterministic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 86d148f9
@@ -4,6 +4,7 @@
 """Attention."""
 import os
+import warnings
 import math
 from importlib.metadata import version
 from contextlib import nullcontext
@@ -355,6 +356,7 @@ class FlashAttention(torch.nn.Module):
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
         attn_mask_type: str = "causal",
+        deterministic: bool = False,
     ) -> None:
         super().__init__()
@@ -366,7 +368,7 @@ class FlashAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
-        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
+        self.deterministic = deterministic
 
     def forward(
         self,
@@ -795,10 +797,20 @@ class DotProductAttention(torch.nn.Module):
         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
 
         self.device_compute_capability = get_device_compute_capability()
+        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
+
         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
             and self.device_compute_capability >= 8.0
         )
+        if _flash_attn_2_available and self.deterministic:
+            self.use_flash_attention = False
+            warnings.warn(
+                "Disabling usage of FlashAttention since version 2 does not support deterministic "
+                "execution. In order to use FA with deterministic behavior, install FlashAttention "
+                "version 1."
+            )
+
         self.use_fused_attention = (
             int(os.getenv("NVTE_FUSED_ATTN", "1"))
             and self.device_compute_capability >= 8.0
@@ -814,7 +826,9 @@ class DotProductAttention(torch.nn.Module):
         self.attention_dropout = attention_dropout
 
         if self.use_flash_attention:
-            self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
+            self.flash_attention = FlashAttention(
+                norm_factor, **attn_kwargs,
+                deterministic=self.deterministic)
         # Instantiating three types since use of flash-attn and FusedAttention
         # might be ruled out due to forward inputs.
         if self.use_fused_attention:
@@ -951,6 +965,13 @@ class DotProductAttention(torch.nn.Module):
             use_fused_attention = (use_fused_attention
                                    and is_backend_avail
                                    and self.num_gqa_groups == self.num_attention_heads)
+            if (self.deterministic
+                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
+                use_fused_attention = False
+                warnings.warn(
+                    "Disabling usage of FusedAttention since the FusedAttention "
+                    "backend does not support deterministic execution."
+                )
 
         if use_flash_attention:
             if checkpoint_core_attention:
...
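
For readers picking up this change: deterministic execution is requested through the NVTE_ALLOW_NONDETERMINISTIC_ALGO environment variable (default "1", i.e. non-deterministic algorithms allowed). Below is a minimal usage sketch of the resulting behavior, assuming flash-attn v2 is installed; the transformer_engine.pytorch import path and the DotProductAttention constructor arguments are illustrative, not taken from this diff.

    import os

    # Opt in to deterministic algorithms before constructing attention modules;
    # the environment variable is read in DotProductAttention.__init__.
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    import transformer_engine.pytorch as te

    # With flash-attn v2 present and determinism requested, construction emits a
    # warning and drops the FlashAttention path; at dispatch time the
    # F16_arbitrary_seqlen FusedAttention backend is likewise ruled out.
    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)  # illustrative args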