Unverified Commit c334fc46 authored by zhujian, committed by GitHub

[PyTorch] Support FA3 for MLA and with CP (#1907)



feature(FA3, MLA, CP):
1. Update FA3 to commit 3ba6f82 (tag 2.8.0.post2 with the compile error fixed); this picks up PR-1604, which adds backward support for hdimQK != hdimV.
2. Update the get_attention_backend method now that FA3 supports MLA.
3. Add CP MLA support for FA3 (see the usage sketch below).
4. Add unit tests for FA3 MLA with CP.
5. Update the attention doc.
Signed-off-by: zhujian <zhujian.whu.cs@gmail.com>
parent 8aee1bb7
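Below is a hedged usage sketch (not part of this commit) of the MLA-shaped attention path these changes enable. It assumes the standard Transformer Engine PyTorch API, that `kv_channels` accepts a `(head_dim_qk, head_dim_v)` tuple for MLA, and that the `NVTE_FLASH_ATTN`/`NVTE_FUSED_ATTN` environment variables steer backend selection; on sm90 with FA3 installed, this should dispatch to FlashAttention 3.

```python
# Hedged sketch: MLA-shaped attention (head_dim_qk=192, head_dim_v=128, as in the new
# cp_3_* test configs) with the flash-attention backend forced via environment variables.
import os

os.environ["NVTE_FLASH_ATTN"] = "1"  # prefer flash-attention
os.environ["NVTE_FUSED_ATTN"] = "0"  # disable cuDNN fused attention

import torch
from transformer_engine.pytorch import DotProductAttention

seq_len, batch, num_heads = 4096, 2, 12
head_dim_qk, head_dim_v = 192, 128

# Assumption: kv_channels takes a (head_dim_qk, head_dim_v) tuple for MLA-shaped attention.
dpa = DotProductAttention(
    num_attention_heads=num_heads,
    kv_channels=(head_dim_qk, head_dim_v),
    qkv_format="sbhd",
    attn_mask_type="causal",
)

q = torch.randn(seq_len, batch, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seq_len, batch, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seq_len, batch, num_heads, head_dim_v, dtype=torch.bfloat16, device="cuda")
out = dpa(q, k, v)  # expected shape: [seq_len, batch, num_heads * head_dim_v]
```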
......@@ -390,7 +390,7 @@
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
......
......@@ -36,6 +36,12 @@ model_configs_flash_attn = {
        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
    ),  # GQA
    "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
+    "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128),  # MLA
+    "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128
+    ),  # MLA
+    "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128),  # MLA
}
......@@ -81,6 +87,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
subprocess.run(
get_bash_arguments(
......
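To make the new skip rule concrete, here is a small illustrative sketch (not from the commit; the cp_comm_type values are the ones typically parametrized in these tests): MLA configs, i.e. head_dim_qk != head_dim_v, only run when the KV exchange uses P2P.

```python
# Illustrative only: which CP communication types run the new MLA configs (e.g. "cp_3_0",
# head_dim_qk=192, head_dim_v=128) versus being skipped by the check added above.
head_dim_qk, head_dim_v = 192, 128
is_mla = head_dim_qk != head_dim_v

for cp_comm_type in ["p2p", "all_gather", "a2a", "a2a+p2p"]:  # assumed parametrization
    runs = (not is_mla) or ("p2p" in cp_comm_type)
    print(f"{cp_comm_type:10s} -> {'runs' if runs else 'skipped'}")
# p2p and a2a+p2p run; all_gather and a2a are skipped for MLA shapes
```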
......@@ -126,10 +126,10 @@ class FlashAttentionUtils:
# Please follow these instructions to install FA3
v3_installation_steps = """\
(1) git clone https://github.com/Dao-AILab/flash-attention.git
-(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
-(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
+(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
v3_warning_printed = False
@staticmethod
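A quick smoke test for the install steps above (illustrative, not part of the commit); it assumes Transformer Engine locates FA3 through the flash_attn_3.flash_attn_interface module that steps (3)-(5) place into site-packages:

```python
# Illustrative post-install check: if steps (1)-(5) succeeded, this import should resolve
# to the interface file copied under <site-packages>/flash_attn_3/.
import flash_attn_3.flash_attn_interface as fa3_interface

print(fa3_interface.__file__)
```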
......@@ -477,11 +477,10 @@ def get_attention_backend(
# Filter: Head dimension
if head_dim_qk != head_dim_v:
-        if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
-            use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
-        ):
-            logger.debug("Disabling FlashAttention as it does not support MLA.")
-            use_flash_attention = False
+        if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+            logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
+            use_flash_attention_2 = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
......@@ -508,9 +507,40 @@ def get_attention_backend(
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
-    if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
+    if use_flash_attention_3:
+
+        def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
+            if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
+                return False
+            if head_dim_qk != head_dim_v:
+                cond1 = 128 < head_dim_qk <= 192
+                cond2 = 96 < head_dim_v <= 128
+                cond3 = head_dim_qk <= 64 and head_dim_v <= 512
+                if not ((cond1 and cond2) or cond3):
+                    return False
+            if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
+                return False
+            return True
+
+        if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
            if FlashAttentionUtils.v3_is_installed:
-            logger.debug("Disabling FlashAttention 3 for head_dim > 128")
+                logger.debug(
+                    "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, "
+                    "head_dim_qk, head_dim_v, or qkv_dtype. "
+                    "Supported: head_dim_qk <= 256 and num_heads %% num_gqa_groups == 0; "
+                    "if head_dim_qk != head_dim_v, then either "
+                    "(head_dim_qk in (128, 192] and head_dim_v in (96, 128]) or "
+                    "(head_dim_qk <= 64 and head_dim_v <= 512); and "
+                    "if head_dim_qk != head_dim_v and head_dim_v > 256, then "
+                    "qkv_dtype must be fp16 or bf16. "
+                    "Found: num_heads = %s, num_gqa_groups = %s, "
+                    "head_dim_qk = %s, head_dim_v = %s, qkv_dtype = %s.",
+                    num_heads,
+                    num_gqa_groups,
+                    head_dim_qk,
+                    head_dim_v,
+                    qkv_dtype,
+                )
            use_flash_attention_3 = False
# Filter: QKV layout
......
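For readers scanning the new constraints, the following worked example (illustrative only; it simply copies the _is_fa3_supported logic from the diff above) shows a few representative head-dimension combinations and whether FA3 would be attempted:

```python
import torch


# Copy of the _is_fa3_supported check added in this commit, for illustration.
def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
    if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
        return False
    if head_dim_qk != head_dim_v:
        cond1 = 128 < head_dim_qk <= 192
        cond2 = 96 < head_dim_v <= 128
        cond3 = head_dim_qk <= 64 and head_dim_v <= 512
        if not ((cond1 and cond2) or cond3):
            return False
    if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
        return False
    return True


# Representative shapes:
print(_is_fa3_supported(12, 12, 192, 128, torch.bfloat16))  # True:  MLA shape used in the cp_3_* tests
print(_is_fa3_supported(12, 12, 64, 512, torch.bfloat16))   # True:  small QK head dim, large V head dim
print(_is_fa3_supported(12, 12, 192, 64, torch.bfloat16))   # False: head_dim_v outside (96, 128]
print(_is_fa3_supported(12, 5, 128, 128, torch.float16))    # False: num_heads not divisible by num_gqa_groups
```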