Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5e4a8223
Unverified
Commit
5e4a8223
authored
Oct 02, 2025
by
vllmellm
Committed by
GitHub
Oct 02, 2025
Browse files
[Qwen][ROCm] Flash Attention Rotary Embeddings (#24642)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
e51de388
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
5 deletions
+28
-5
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/common.py
+23
-0
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+5
-5
No files found.
vllm/model_executor/layers/rotary_embedding/common.py
View file @
5e4a8223
...
@@ -2,15 +2,21 @@
...
@@ -2,15 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
functools
import
cache
from
importlib.util
import
find_spec
from
typing
import
Callable
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
logger
=
init_logger
(
__name__
)
# common functions
# common functions
def
rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
...
@@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
return
apply_rotary_emb_torch
(
x
,
cos
,
sin
,
is_neox_style
)
return
apply_rotary_emb_torch
(
x
,
cos
,
sin
,
is_neox_style
)
@
cache
def
dispatch_rotary_emb_function
()
->
Callable
[...,
torch
.
Tensor
]:
if
current_platform
.
is_cuda
():
return
apply_rotary_emb
if
current_platform
.
is_rocm
():
if
find_spec
(
"flash_attn"
)
is
not
None
:
from
flash_attn.ops.triton.rotary
import
apply_rotary
return
apply_rotary
else
:
logger
.
warning
(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings."
)
return
apply_rotary_emb_torch
# yarn functions
# yarn functions
# Inverse dim formula to find dim based on number of rotations
# Inverse dim formula to find dim based on number of rotations
def
yarn_find_correction_dim
(
num_rotations
:
int
,
def
yarn_find_correction_dim
(
num_rotations
:
int
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
5e4a8223
...
@@ -50,6 +50,8 @@ from vllm.model_executor.layers.activation import QuickGELU
...
@@ -50,6 +50,8 @@ from vllm.model_executor.layers.activation import QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding.common
import
(
dispatch_rotary_emb_function
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -63,7 +65,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -63,7 +65,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
...
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def
apply_rotary_pos_emb_vision
(
t
:
torch
.
Tensor
,
def
apply_rotary_pos_emb_vision
(
t
:
torch
.
Tensor
,
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rotary_emb_function
=
dispatch_rotary_emb_function
()
t_
=
t
.
float
()
t_
=
t
.
float
()
cos
=
freqs
.
cos
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
sin
=
freqs
.
sin
()
apply_rotary_emb
=
apply_rotary_emb_torch
output
=
rotary_emb_function
(
t_
,
cos
,
sin
).
type_as
(
t
)
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
output
=
apply_rotary_emb
(
t_
,
cos
,
sin
).
type_as
(
t
)
return
output
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment