Unverified commit 122de2cc authored by cyanguwa, committed by GitHub

Relax checks for attn_mask_type in FlashAttention (#226)



* relax attn mask type checks for FlashAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable flash attn if mask tensor is not None
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix the logic for flash attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent a5f61ce2
@@ -281,9 +281,6 @@ class FlashAttention(torch.nn.Module):
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
-        assert (
-            attn_mask_type == "causal"
-        ), 'FlashAttention currently only supports causal attention mask.'
         self.attn_causal_mask = attn_mask_type == "causal"
         self.norm_factor = norm_factor
@@ -296,7 +293,6 @@ class FlashAttention(torch.nn.Module):
         query_layer: torch.Tensor,
         key_layer: torch.Tensor,
         value_layer: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""
@@ -308,9 +304,6 @@ class FlashAttention(torch.nn.Module):
         assert (
             query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
         ), 'FlashAttention currently only supports CUDA tensors.'
-        assert (
-            attention_mask is None
-        ), 'FlashAttention currently does not support external attention mask.'
         # For now just 128, will make it more general in the future
@@ -428,7 +421,6 @@ class DotProductAttention(torch.nn.Module):
         self.device_compute_capability = get_device_compute_capability()
         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
-            and attn_mask_type == "causal"
             and self.device_compute_capability >= 8.0
         )
@@ -437,6 +429,7 @@
             "attention_dropout_ctx": attention_dropout_ctx,
             "attn_mask_type": attn_mask_type,
         }
+        self.attn_mask_type = attn_mask_type
         if self.use_flash_attention:
             self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
@@ -514,6 +507,9 @@
         ):
             use_flash_attention = False
+        if self.attn_mask_type == "padding" and attention_mask is not None:
+            use_flash_attention = False
         if is_in_onnx_export_mode():
             use_flash_attention = False
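To make the effect of the relaxed checks concrete, below is a minimal, self-contained sketch of the backend-selection logic after this change. It is not the transformer_engine API: choose_backend and its arguments are hypothetical stand-ins for what DotProductAttention.__init__ and forward() actually do. Flash attention is no longer restricted to attn_mask_type == "causal", but since flash-attn still takes no external mask tensor, a "padding" mask type combined with an actual mask falls back to the unfused path.

import os

# Hypothetical helper mirroring the gating introduced by this commit; in the
# real code this logic is split between DotProductAttention.__init__ and forward().
def choose_backend(attn_mask_type, attention_mask, device_compute_capability):
    # Constructor-time gating: NVTE_FLASH_ATTN opt-out plus an sm80+ GPU.
    # The "attn_mask_type == 'causal'" condition removed by this commit is gone.
    use_flash_attention = bool(int(os.getenv("NVTE_FLASH_ATTN", "1"))) and (
        device_compute_capability >= 8.0
    )
    # Forward-time gating: flash-attn accepts no external attention mask, so a
    # "padding" mask type with a concrete mask tensor uses the unfused path.
    if attn_mask_type == "padding" and attention_mask is not None:
        use_flash_attention = False
    return "flash" if use_flash_attention else "unfused"

print(choose_backend("causal", None, 9.0))     # flash
print(choose_backend("padding", None, 9.0))    # flash: mask type alone no longer disables it
print(choose_backend("padding", "mask", 9.0))  # unfused: explicit mask tensor present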