Unverified Commit d978e800 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Fix attention mask type for Flash Attention + CP + THD (#1354)



* always have padding mask type for both flash and fused attentions
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove a redundant assert
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
parent 8c004241
......@@ -42,7 +42,7 @@ def run_dpa_with_cp(
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
......
......@@ -4309,14 +4309,6 @@ def attn_forward_func_with_cp(
assert (
qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
assert (
qkv_format != "thd"
or not use_fused_attention
or attn_mask_type in ["padding", "padding_causal"]
), (
f"Context parallelism is not supported for {attn_mask_type} mask type and "
f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
)
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
......@@ -7878,6 +7870,9 @@ class DotProductAttention(TransformerEngineBaseModule):
), f"Values have head_dim = {value_layer.shape[-1]}, "
"but expected head_dim = {self.hidden_size_per_attention_head_v}!"
if qkv_format is None:
qkv_format = self.qkv_format
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
......@@ -7904,9 +7899,6 @@ class DotProductAttention(TransformerEngineBaseModule):
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if qkv_format is None:
qkv_format = self.qkv_format
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment