"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "88a94cc50d057dbf9b4ecb0bc39159a225f5f780"
Unverified Commit 93c5c65b authored by Charlene Yang, committed by GitHub

[PyTorch] Add THD support for max_logit/MuonClip (#2480)



* update FE; initial pass at thd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* produce Stats+Max instead of Max+Sum_Exp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "produce Stats+Max instead of Max+Sum_Exp"

This reverts commit c7d2b77b2da9ff3f68344097284187ac427eeb6a.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent e411547b
-Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7
+Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93
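For context (not part of the diff): "max_logit" here is the per-head maximum of the pre-softmax attention scores, which MuonClip-style QK clipping typically uses to decide whether to rescale the query/key projections. A rough PyTorch illustration for a single sequence (illustrative only, not code from this commit):

import torch

# Illustration only: max_logit is the largest pre-softmax attention score,
# tracked per head. One sequence, THD-style [tokens, heads, head_dim] tensors:
q = torch.randn(7, 16, 64)
k = torch.randn(7, 16, 64)
scores = torch.einsum("qhd,khd->hqk", q, k) / 64 ** 0.5   # [heads, q_tokens, k_tokens]
max_logit = scores.amax(dim=(-2, -1))                     # [16]: one max logit per head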
@@ -1101,7 +1101,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_Max->data.dptr = nullptr;
     if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-      output_Max->data.shape = {max_tokens_q, num_attn_heads, 1};
+      output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
     } else {
       output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     }
@@ -1109,7 +1109,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_Sum_Exp->data.dptr = nullptr;
     if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-      output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1};
+      output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
     } else {
       output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     }
@@ -1118,7 +1118,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_S->data.dptr = nullptr;
     if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
+      output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
     } else {
       output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     }
...
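The hunks above size the auxiliary max/stats tensors per packed token for THD input ({num_tokens_q, num_attn_heads, 1}) rather than per padded position ({batch, num_attn_heads, max_seqlen_q, 1}) as in the BSHD/SBHD branch. A small sketch of that shape logic, with cu_seqlens_q as the usual THD cumulative-sequence-length tensor (illustrative helper, not an API from this repo):

import torch

def stats_shape(qkv_format, cu_seqlens_q, num_heads, batch, max_seqlen_q):
    """Illustrative only: shape of a per-step softmax stats / max-logit buffer."""
    if qkv_format == "thd":
        total_tokens_q = int(cu_seqlens_q[-1])   # packed tokens across the batch
        return (total_tokens_q, num_heads, 1)
    return (batch, num_heads, max_seqlen_q, 1)

# Example: batch of 3 sequences with lengths 5, 2, 7 packed into 14 tokens.
cu_seqlens_q = torch.tensor([0, 5, 7, 14], dtype=torch.int32)
print(stats_shape("thd", cu_seqlens_q, num_heads=16, batch=3, max_seqlen_q=7))   # (14, 16, 1)
print(stats_shape("bshd", cu_seqlens_q, num_heads=16, batch=3, max_seqlen_q=7))  # (3, 16, 7, 1)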
@@ -532,9 +532,6 @@ def get_attention_backend(
         if use_flash_attention:
             use_flash_attention = False
             logger.debug("Disabling FlashAttention for max_logit")
-        if use_fused_attention and qkv_format == "thd":
-            use_fused_attention = False
-            logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd")
     if fp8 and fp8_meta["recipe"].fp8_dpa:
         use_flash_attention = False
         use_fused_attention = False
@@ -677,9 +674,6 @@ def get_attention_backend(
     # Filter: QKV layout
     if qkv_format == "thd":
-        if use_unfused_attention:
-            logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
-            use_unfused_attention = False
         if pad_between_seqs:
             if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
                 use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
...
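With the two backend filters above removed, requesting max_logit together with qkv_format = "thd" no longer forces FusedAttention off. A minimal sketch of a THD-format call, assuming TE's DotProductAttention with its cu_seqlens/max_seqlen arguments; the flag that actually requests max_logit is not visible in this diff, so it is omitted here:

import torch
import transformer_engine.pytorch as te

# Hedged sketch: THD-packed q/k/v with cumulative sequence lengths.
heads, head_dim = 16, 64
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32, device="cuda")
total_tokens, max_seqlen = int(cu_seqlens[-1]), 7

q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = te.DotProductAttention(heads, head_dim, qkv_format="thd",
                              attn_mask_type="padding_causal")
out = attn(q, k, v,
           cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
           max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen)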