Unverified commit d9b4bfb5, authored by Reese Wang, committed by GitHub
Browse files

Add THD + GQA supports (#1260)



Add THD + GQA support for cuDNN >= 9.6
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 35f7d262
...@@ -305,11 +305,14 @@ class FusedAttnRunner: ...@@ -305,11 +305,14 @@ class FusedAttnRunner:
]: ]:
pytest.skip("THD format requires padding masks.") pytest.skip("THD format requires padding masks.")
if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD: qkv_format = get_qkv_format(self.qkv_layout)
if self.num_heads_q != self.num_heads_kv: if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
if self.max_seqlen_q != self.max_seqlen_kv: if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.") pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
if self.num_heads_q != self.num_heads_kv:
pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")
if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None: if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
pytest.skip( pytest.skip(
......
...@@ -181,10 +181,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -181,10 +181,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
// qkv format // qkv format
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) || ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups && (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
qkv_format == NVTE_QKV_Format::NVTE_THD) || (cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
(qkv_format == NVTE_QKV_Format::NVTE_BSHD)) && (cudnn_runtime_version >= 90600))) &&
// sliding window // sliding window
((cudnn_runtime_version < 90200 && window_size_left == -1 && ((cudnn_runtime_version < 90200 && window_size_left == -1 &&
(window_size_right == -1 || window_size_right == 0)) || (window_size_right == -1 || window_size_right == 0)) ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment