Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f38f8c97
Unverified
Commit
f38f8c97
authored
Feb 24, 2026
by
Rohan Potdar
Committed by
GitHub
Feb 25, 2026
Browse files
[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
ec1d30c0
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
139 additions
and
67 deletions
+139
-67
tests/compile/passes/test_rope_kvcache_fusion.py
tests/compile/passes/test_rope_kvcache_fusion.py
+9
-1
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+65
-36
vllm/compilation/passes/fusion/matcher_utils.py
vllm/compilation/passes/fusion/matcher_utils.py
+5
-0
vllm/compilation/passes/utility/scatter_split_replace.py
vllm/compilation/passes/utility/scatter_split_replace.py
+5
-1
vllm/config/compilation.py
vllm/config/compilation.py
+13
-16
vllm/config/vllm.py
vllm/config/vllm.py
+20
-3
vllm/envs.py
vllm/envs.py
+3
-3
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+11
-7
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+8
-0
No files found.
tests/compile/passes/test_rope_kvcache_fusion.py
View file @
f38f8c97
...
@@ -177,7 +177,10 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
...
@@ -177,7 +177,10 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
def
ops_in_model_before
(
self
)
->
list
[
torch
.
_ops
.
OpOverload
]:
def
ops_in_model_before
(
self
)
->
list
[
torch
.
_ops
.
OpOverload
]:
ops
=
[]
ops
=
[]
if
self
.
enable_rope_custom_op
:
if
self
.
enable_rope_custom_op
:
ops
.
append
(
ROTARY_OP
)
if
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
():
ops
.
append
(
torch
.
ops
.
vllm
.
rocm_aiter_triton_rotary_embedding
.
default
)
else
:
ops
.
append
(
ROTARY_OP
)
else
:
else
:
ops
.
append
(
INDEX_SELECT_OP
)
ops
.
append
(
INDEX_SELECT_OP
)
ops
.
append
(
torch
.
ops
.
vllm
.
unified_kv_cache_update
.
default
)
ops
.
append
(
torch
.
ops
.
vllm
.
unified_kv_cache_update
.
default
)
...
@@ -196,6 +199,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
...
@@ -196,6 +199,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"enable_rope_custom_op"
,
[
True
])
# [True, False])
@
pytest
.
mark
.
parametrize
(
"enable_rope_custom_op"
,
[
True
])
# [True, False])
@
pytest
.
mark
.
parametrize
(
"enable_aiter_triton_rope"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_kv_heads"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"num_kv_heads"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
])
...
@@ -210,6 +214,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
...
@@ -210,6 +214,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
def
test_rope_kvcache_fusion
(
def
test_rope_kvcache_fusion
(
attn_backend
:
AttentionBackendEnum
,
attn_backend
:
AttentionBackendEnum
,
enable_rope_custom_op
:
bool
,
enable_rope_custom_op
:
bool
,
enable_aiter_triton_rope
:
bool
,
num_heads
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
...
@@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(
...
@@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER_TRITON_ROPE"
,
"1"
if
enable_aiter_triton_rope
else
"0"
)
rocm_aiter_ops
.
refresh_env_variables
()
rocm_aiter_ops
.
refresh_env_variables
()
model
=
QKRoPEKVCacheTestModel
(
model
=
QKRoPEKVCacheTestModel
(
...
...
vllm/_aiter_ops.py
View file @
f38f8c97
...
@@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
...
@@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return
out
,
residual_out
return
out
,
residual_out
def _triton_rotary_embedding_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    offsets: torch.Tensor | None = None,
) -> None:
    """Apply rotary embedding to ``query`` and ``key`` in-place via AITER.

    Args:
        positions: Token positions; flattened to one entry per token.
        query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
        head_size: Size of the last dimension of each head.
        cos_sin_cache: Cache that is split in two halves along its last
            dimension; the first half is passed as ``cos`` and the second
            as ``sin``.
        is_neox: Selects the kernel's ``rotate_style`` (0 when True,
            1 otherwise).
        offsets: Optional tensor forwarded unchanged to the AITER kernel.

    Returns:
        None. The kernel writes through views of ``query`` and ``key``, so
        the caller's tensors are modified in-place.
    """
    # Imported lazily so this module does not need aiter at import time.
    from aiter.ops.triton.rope.rope import (
        rope_cached_thd_positions_offsets_2c_fwd_inplace,
    )

    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    rotate_style = 0 if is_neox else 1
    # NOTE: rotary_dim is pinned to head_size, so the slices below span the
    # whole head; a rotary dimension smaller than head_size is not handled.
    rotary_dim = head_size

    # Views share storage with the caller's tensors; there is no need to
    # restore the original shapes afterwards because nothing is returned.
    query_ = query.view(num_tokens, -1, head_size)[..., :rotary_dim]
    key_ = key.view(num_tokens, -1, head_size)[..., :rotary_dim]
    positions = positions.view(num_tokens)

    rope_cached_thd_positions_offsets_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        offsets,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
def
_triton_rotary_embedding_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
return
# Global flag to ensure ops are registered only once
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
_OPS_REGISTERED
=
False
...
@@ -1178,6 +1231,14 @@ class rocm_aiter_ops:
...
@@ -1178,6 +1231,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
# Register rocm aiter rotary embedding custom op
direct_register_custom_op
(
op_name
=
"rocm_aiter_triton_rotary_embedding"
,
op_func
=
_triton_rotary_embedding_impl
,
mutates_args
=
[
"query"
,
"key"
],
# These tensors are modified in-place
fake_impl
=
_triton_rotary_embedding_fake
,
)
_OPS_REGISTERED
=
True
_OPS_REGISTERED
=
True
@
staticmethod
@
staticmethod
...
@@ -1220,6 +1281,10 @@ class rocm_aiter_ops:
...
@@ -1220,6 +1281,10 @@ class rocm_aiter_ops:
def
get_triton_add_rmsnorm_pad_op
()
->
OpOverload
:
def
get_triton_add_rmsnorm_pad_op
()
->
OpOverload
:
return
torch
.
ops
.
vllm
.
rocm_aiter_triton_add_rmsnorm_pad
.
default
return
torch
.
ops
.
vllm
.
rocm_aiter_triton_add_rmsnorm_pad
.
default
@staticmethod
def get_triton_rotary_embedding_op() -> OpOverload:
    """Return the default overload of the AITER Triton rotary embedding op.

    The op is looked up under ``torch.ops.vllm`` by the name
    ``rocm_aiter_triton_rotary_embedding``.
    """
    op_packet = torch.ops.vllm.rocm_aiter_triton_rotary_embedding
    return op_packet.default
@
staticmethod
@
staticmethod
def
rms_norm
(
def
rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
...
@@ -1482,42 +1547,6 @@ class rocm_aiter_ops:
...
@@ -1482,42 +1547,6 @@ class rocm_aiter_ops:
gemm_afp4wfp4
(
x_q
,
weight
,
x_s
,
weight_scale
.
T
,
out_dtype
,
y
)
gemm_afp4wfp4
(
x_q
,
weight
,
x_s
,
weight_scale
.
T
,
out_dtype
,
y
)
return
y
return
y
@staticmethod
def triton_rotary_embed(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
):
    """Apply rotary embedding to ``query`` and ``key`` in-place via AITER.

    Args:
        positions: Token positions; flattened to one entry per token.
        query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
        cos_sin_cache: Cache that is split in two halves along its last
            dimension; the first half is passed as ``cos`` and the second
            as ``sin``.
        head_size: Size of the last dimension of each head.
        rotary_dim: Number of leading elements of each head handed to the
            kernel.
        is_neox_style: Selects the kernel's ``rotate_style`` (0 when True,
            1 otherwise).

    Returns:
        None. The kernel writes through views of ``query`` and ``key``, so
        the caller's tensors are modified in-place.
    """
    # Imported lazily so this module does not need aiter at import time.
    from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    rotate_style = 0 if is_neox_style else 1

    # Views share storage with the caller's tensors; there is no need to
    # restore the original shapes afterwards because nothing is returned.
    query_ = query.view(num_tokens, -1, head_size)[..., :rotary_dim]
    key_ = key.view(num_tokens, -1, head_size)[..., :rotary_dim]
    positions = positions.view(num_tokens)

    rope_cached_thd_positions_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
@
staticmethod
@
staticmethod
def
triton_rope_and_cache
(
def
triton_rope_and_cache
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
vllm/compilation/passes/fusion/matcher_utils.py
View file @
f38f8c97
...
@@ -89,10 +89,13 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
...
@@ -89,10 +89,13 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
num_heads
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
use_flashinfer
:
bool
=
False
,
use_flashinfer
:
bool
=
False
,
match_rocm_aiter
:
bool
|
None
=
None
,
enabled
:
bool
|
None
=
None
,
enabled
:
bool
|
None
=
None
,
)
->
None
:
)
->
None
:
if
enabled
is
None
:
if
enabled
is
None
:
enabled
=
RotaryEmbedding
.
enabled
()
enabled
=
RotaryEmbedding
.
enabled
()
if
match_rocm_aiter
is
None
:
match_rocm_aiter
=
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
()
super
().
__init__
(
enabled
)
super
().
__init__
(
enabled
)
self
.
is_neox
=
is_neox
self
.
is_neox
=
is_neox
...
@@ -104,6 +107,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
...
@@ -104,6 +107,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
self
.
rotary_dim
=
head_size
self
.
rotary_dim
=
head_size
if
use_flashinfer
:
if
use_flashinfer
:
self
.
rotary_op
=
FLASHINFER_ROTARY_OP
self
.
rotary_op
=
FLASHINFER_ROTARY_OP
elif
match_rocm_aiter
:
self
.
rotary_op
=
rocm_aiter_ops
.
get_triton_rotary_embedding_op
()
else
:
else
:
self
.
rotary_op
=
ROTARY_OP
self
.
rotary_op
=
ROTARY_OP
...
...
vllm/compilation/passes/utility/scatter_split_replace.py
View file @
f38f8c97
...
@@ -60,6 +60,10 @@ class ScatterSplitReplacementPass(VllmInductorPass):
...
@@ -60,6 +60,10 @@ class ScatterSplitReplacementPass(VllmInductorPass):
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
count
=
0
count
=
0
target_ops
=
[
torch
.
ops
.
_C
.
rotary_embedding
.
default
]
if
hasattr
(
torch
.
ops
.
vllm
,
"rocm_aiter_triton_rotary_embedding"
):
target_ops
.
append
(
torch
.
ops
.
vllm
.
rocm_aiter_triton_rotary_embedding
.
default
)
for
node
in
graph
.
nodes
:
for
node
in
graph
.
nodes
:
if
not
is_func
(
node
,
auto_functionalized
):
if
not
is_func
(
node
,
auto_functionalized
):
continue
continue
...
@@ -67,7 +71,7 @@ class ScatterSplitReplacementPass(VllmInductorPass):
...
@@ -67,7 +71,7 @@ class ScatterSplitReplacementPass(VllmInductorPass):
kwargs
=
node
.
kwargs
kwargs
=
node
.
kwargs
at_target
=
node
.
args
[
0
]
at_target
=
node
.
args
[
0
]
if
at_target
==
torch
.
ops
.
_C
.
rotary_embedding
.
default
:
if
at_target
in
target_ops
:
query
=
kwargs
[
"query"
]
query
=
kwargs
[
"query"
]
key
=
kwargs
[
"key"
]
key
=
kwargs
[
"key"
]
getitem_nodes
=
{}
getitem_nodes
=
{}
...
...
vllm/config/compilation.py
View file @
f38f8c97
...
@@ -123,6 +123,8 @@ class PassConfig:
...
@@ -123,6 +123,8 @@ class PassConfig:
"""Enable async TP."""
"""Enable async TP."""
fuse_allreduce_rms
:
bool
=
Field
(
default
=
None
)
fuse_allreduce_rms
:
bool
=
Field
(
default
=
None
)
"""Enable flashinfer allreduce fusion."""
"""Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion
:
bool
=
False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions
# ROCm/AITER specific fusions
fuse_act_padding
:
bool
=
Field
(
default
=
None
)
fuse_act_padding
:
bool
=
Field
(
default
=
None
)
...
@@ -153,8 +155,6 @@ class PassConfig:
...
@@ -153,8 +155,6 @@ class PassConfig:
8: 1, # 1MB
8: 1, # 1MB
},
},
}, where key is the device capability"""
}, where key is the device capability"""
enable_qk_norm_rope_fusion
:
bool
=
False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
# TODO(luka) better pass enabling system.
...
@@ -834,23 +834,20 @@ class CompilationConfig:
...
@@ -834,23 +834,20 @@ class CompilationConfig:
func
if
isinstance
(
func
,
InductorPass
)
else
CallableInductorPass
(
func
)
func
if
isinstance
(
func
,
InductorPass
)
else
CallableInductorPass
(
func
)
)
)
if
self
.
pass_config
.
enable_qk_norm_rope_fusion
:
if
(
self
.
pass_config
.
enable_qk_norm_rope_fusion
and
"+rotary_embedding"
not
in
self
.
custom_ops
):
# TODO(zhuhaoran): support rope native forward match and remove this.
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self
.
custom_ops
.
append
(
"+rotary_embedding"
)
self
.
custom_ops
.
append
(
"+rotary_embedding"
)
if
self
.
pass_config
.
fuse_rope_kvcache
:
if
(
from
vllm._aiter_ops
import
rocm_aiter_ops
self
.
pass_config
.
fuse_rope_kvcache
and
"+rotary_embedding"
not
in
self
.
custom_ops
if
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
():
):
logger
.
warning
(
# TODO(Rohan138): support rope native forward match and remove this.
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
self
.
custom_ops
.
append
(
"+rotary_embedding"
)
)
self
.
pass_config
.
fuse_rope_kvcache
=
False
else
:
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self
.
custom_ops
.
append
(
"+rotary_embedding"
)
if
(
if
(
is_torch_equal_or_newer
(
"2.9.0.dev"
)
is_torch_equal_or_newer
(
"2.9.0.dev"
)
...
...
vllm/config/vllm.py
View file @
f38f8c97
...
@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
...
@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
)
)
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether RoPE + KV-cache fusion should be enabled.

    The result is truthy only when all three conditions hold:
    AITER is enabled (``rocm_aiter_ops.is_enabled()``), the
    ``rotary_embedding`` custom op is enabled in the compilation config,
    and ``use_inductor_graph_partition`` is set.
    """
    # Imported inside the function rather than at module scope.
    from vllm._aiter_ops import rocm_aiter_ops

    return (
        rocm_aiter_ops.is_enabled()
        and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
        and cfg.compilation_config.use_inductor_graph_partition
    )
def
enable_norm_pad_fusion
(
cfg
:
"VllmConfig"
)
->
bool
:
def
enable_norm_pad_fusion
(
cfg
:
"VllmConfig"
)
->
bool
:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
from
vllm._aiter_ops
import
rocm_aiter_ops
return
(
return
(
envs
.
VLLM_ROCM_USE_AITER
rocm_aiter_ops
.
is_rmsnorm_enabled
()
and
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
and
not
rocm_aiter_ops
.
is_triton_gemm_enabled
()
and
envs
.
VLLM_ROCM_USE_AITER_TRITON_GEMM
and
cfg
.
model_config
is
not
None
and
cfg
.
model_config
is
not
None
and
cfg
.
model_config
.
get_hidden_size
()
==
2880
and
cfg
.
model_config
.
get_hidden_size
()
==
2880
)
)
...
@@ -149,6 +162,7 @@ OPTIMIZATION_LEVEL_00 = {
...
@@ -149,6 +162,7 @@ OPTIMIZATION_LEVEL_00 = {
"enable_sp"
:
False
,
"enable_sp"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_act_padding"
:
False
,
"fuse_act_padding"
:
False
,
"fuse_rope_kvcache"
:
False
,
},
},
"cudagraph_mode"
:
CUDAGraphMode
.
NONE
,
"cudagraph_mode"
:
CUDAGraphMode
.
NONE
,
"use_inductor_graph_partition"
:
False
,
"use_inductor_graph_partition"
:
False
,
...
@@ -167,6 +181,7 @@ OPTIMIZATION_LEVEL_01 = {
...
@@ -167,6 +181,7 @@ OPTIMIZATION_LEVEL_01 = {
"enable_sp"
:
False
,
"enable_sp"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_rope_kvcache"
:
enable_rope_kvcache_fusion
,
},
},
"cudagraph_mode"
:
CUDAGraphMode
.
PIECEWISE
,
"cudagraph_mode"
:
CUDAGraphMode
.
PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
"use_inductor_graph_partition"
:
False
,
...
@@ -185,6 +200,7 @@ OPTIMIZATION_LEVEL_02 = {
...
@@ -185,6 +200,7 @@ OPTIMIZATION_LEVEL_02 = {
"enable_sp"
:
IS_DENSE
,
"enable_sp"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_rope_kvcache"
:
enable_rope_kvcache_fusion
,
},
},
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
"use_inductor_graph_partition"
:
False
,
...
@@ -203,6 +219,7 @@ OPTIMIZATION_LEVEL_03 = {
...
@@ -203,6 +219,7 @@ OPTIMIZATION_LEVEL_03 = {
"enable_sp"
:
IS_DENSE
,
"enable_sp"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
"fuse_rope_kvcache"
:
enable_rope_kvcache_fusion
,
},
},
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
"use_inductor_graph_partition"
:
False
,
...
...
vllm/envs.py
View file @
f38f8c97
...
@@ -105,7 +105,7 @@ if TYPE_CHECKING:
...
@@ -105,7 +105,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA
:
bool
=
True
VLLM_ROCM_USE_AITER_MLA
:
bool
=
True
VLLM_ROCM_USE_AITER_MHA
:
bool
=
True
VLLM_ROCM_USE_AITER_MHA
:
bool
=
True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
:
bool
=
False
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
:
bool
=
False
VLLM_ROCM_USE_AITER_TRITON_ROPE
:
bool
=
Fals
e
VLLM_ROCM_USE_AITER_TRITON_ROPE
:
bool
=
Tru
e
VLLM_ROCM_USE_AITER_FP8BMM
:
bool
=
True
VLLM_ROCM_USE_AITER_FP8BMM
:
bool
=
True
VLLM_ROCM_USE_AITER_FP4BMM
:
bool
=
True
VLLM_ROCM_USE_AITER_FP4BMM
:
bool
=
True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
:
bool
=
False
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
:
bool
=
False
...
@@ -937,9 +937,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -937,9 +937,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)
),
),
# Whether to use aiter rope.
# Whether to use aiter rope.
# By default is
dis
abled.
# By default is
en
abled.
"VLLM_ROCM_USE_AITER_TRITON_ROPE"
:
lambda
:
(
"VLLM_ROCM_USE_AITER_TRITON_ROPE"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_TRITON_ROPE"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)
os
.
getenv
(
"VLLM_ROCM_USE_AITER_TRITON_ROPE"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)
),
),
# Whether to use aiter triton fp8 bmm kernel
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
# By default is enabled.
...
...
vllm/model_executor/layers/rotary_embedding/base.py
View file @
f38f8c97
...
@@ -47,15 +47,20 @@ class RotaryEmbeddingBase(CustomOp):
...
@@ -47,15 +47,20 @@ class RotaryEmbeddingBase(CustomOp):
if
not
hasattr
(
self
,
"use_flashinfer"
):
if
not
hasattr
(
self
,
"use_flashinfer"
):
self
.
use_flashinfer
=
False
self
.
use_flashinfer
=
False
self
.
use_aiter
=
(
self
.
enabled
()
and
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
()
)
if
self
.
use_aiter
:
self
.
rocm_aiter_triton_rotary_embedding
=
(
rocm_aiter_ops
.
get_triton_rotary_embedding_op
()
)
if
init_cache
:
if
init_cache
:
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
if
not
self
.
use_flashinfer
:
if
not
self
.
use_flashinfer
:
cache
=
cache
.
to
(
dtype
)
cache
=
cache
.
to
(
dtype
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
is_rocm_triton_rotary_embed_enabled
=
(
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
()
)
self
.
apply_rotary_emb
=
ApplyRotaryEmb
(
self
.
apply_rotary_emb
=
ApplyRotaryEmb
(
is_neox_style
=
self
.
is_neox_style
,
is_neox_style
=
self
.
is_neox_style
,
...
@@ -231,15 +236,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
...
@@ -231,15 +236,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
=
None
,
key
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
self
.
is_rocm_triton_rotary_embed_enabled
:
if
self
.
use_aiter
:
cos_sin_cache
=
self
.
_match_cos_sin_cache_dtype
(
query
)
cos_sin_cache
=
self
.
_match_cos_sin_cache_dtype
(
query
)
rocm_aiter_
ops
.
triton_rotary_embed
(
self
.
rocm_aiter_triton_rotary_embed
ding
(
positions
,
positions
,
query
,
query
,
key
,
key
,
cos_sin_cache
,
self
.
head_size
,
self
.
head_size
,
self
.
rotary_dim
,
cos_sin_cache
,
self
.
is_neox_style
,
self
.
is_neox_style
,
)
)
return
query
,
key
return
query
,
key
...
...
vllm/platforms/rocm.py
View file @
f38f8c97
...
@@ -494,6 +494,7 @@ class RocmPlatform(Platform):
...
@@ -494,6 +494,7 @@ class RocmPlatform(Platform):
use_aiter_rms_norm
=
rocm_aiter_ops
.
is_rmsnorm_enabled
()
use_aiter_rms_norm
=
rocm_aiter_ops
.
is_rmsnorm_enabled
()
use_aiter_fp8_linear
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
use_aiter_fp8_linear
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
use_aiter_fused_se
=
rocm_aiter_ops
.
is_fusion_moe_shared_experts_enabled
()
use_aiter_fused_se
=
rocm_aiter_ops
.
is_fusion_moe_shared_experts_enabled
()
use_aiter_triton_rope
=
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
()
if
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
if
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
# decode context parallel does not support full cudagraphs
# decode context parallel does not support full cudagraphs
...
@@ -558,6 +559,13 @@ class RocmPlatform(Platform):
...
@@ -558,6 +559,13 @@ class RocmPlatform(Platform):
and
"-grouped_topk"
not
in
compilation_config
.
custom_ops
and
"-grouped_topk"
not
in
compilation_config
.
custom_ops
):
):
compilation_config
.
custom_ops
.
append
(
"+grouped_topk"
)
compilation_config
.
custom_ops
.
append
(
"+grouped_topk"
)
# Enable the rotary embedding custom op when AITER Triton RoPE is in use, unless the user has explicitly enabled or disabled it
if
(
use_aiter_triton_rope
and
"+rotary_embedding"
not
in
compilation_config
.
custom_ops
and
"-rotary_embedding"
not
in
compilation_config
.
custom_ops
):
compilation_config
.
custom_ops
.
append
(
"+rotary_embedding"
)
# Default dispatch to rocm's sparse_attn_indexer implementation
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config
.
custom_ops
.
append
(
"+sparse_attn_indexer"
)
compilation_config
.
custom_ops
.
append
(
"+sparse_attn_indexer"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment