Unverified Commit b13a4475 authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm (#27748)


Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
parent 7956b0c0
...@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function( ...@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
if current_platform.is_cuda(): if current_platform.is_cuda():
return apply_rotary_emb return apply_rotary_emb
if current_platform.is_rocm(): # if torch compile is not enabled
# use rotary embedding function from flash_attn package
# otherwise use the naive pytorch embedding implementation
# is faster when torch compile is enabled.
if current_platform.is_rocm() and not torch.compiler.is_compiling():
if find_spec("flash_attn") is not None: if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary from flash_attn.ops.triton.rotary import apply_rotary
...@@ -87,10 +91,9 @@ def dispatch_rotary_emb_function( ...@@ -87,10 +91,9 @@ def dispatch_rotary_emb_function(
"flash_attn is not installed. Falling back to PyTorch " "flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings." "implementation for rotary embeddings."
) )
if default is not None: if default is not None:
return default return default
else:
return apply_rotary_emb_torch return apply_rotary_emb_torch
......
...@@ -370,7 +370,7 @@ class Glm4vVisionAttention(nn.Module): ...@@ -370,7 +370,7 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens_k=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen, max_seqlen_k=max_seqlen,
dropout_p=0, dropout_p=0.0,
causal=False, causal=False,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment