Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c195d43
Unverified
Commit
7c195d43
authored
Sep 10, 2025
by
vllmellm
Committed by
GitHub
Sep 10, 2025
Browse files
[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
0ae43dbf
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
108 additions
and
36 deletions
+108
-36
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+22
-17
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+71
-16
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+15
-3
No files found.
tests/model_executor/test_enabled_custom_ops.py
View file @
7c195d43
...
@@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
...
@@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
vllm_topk_softmax
)
vllm_topk_softmax
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
)
is_rocm_aiter_moe_enabled
)
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
RMSNorm
,
RMSNorm
,
dispatch_
cuda
_rmsnorm_func
,
fused_add_rms_norm
,
rms_norm
,
dispatch_
rocm
_rmsnorm_func
,
rocm_aiter_
fused_add_rms_norm
,
rocm_aiter_
rms_norm
)
fused_add_rms_norm
,
rms_norm
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
cutlass_scaled_mm
,
dispatch_w8a8_blockscale_func
,
w8a8_block_fp8_matmul
)
cutlass_scaled_mm
,
dispatch_w8a8_blockscale_func
,
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
RMS_NORM_SUPPORTED_DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
# Registered subclass for test
# Registered subclass for test
@
CustomOp
.
register
(
"relu3"
)
@
CustomOp
.
register
(
"relu3"
)
...
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
...
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter_norm"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter_norm"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"AITER is a feature exclusive for ROCm"
)
reason
=
"AITER is a feature exclusive for ROCm"
)
def
test_rms_norm_dispatch
(
add_residual
:
bool
,
use_rocm_aiter
:
str
,
def
test_rms_norm_dispatch
(
add_residual
:
bool
,
dtype
:
torch
.
dtype
,
use_rocm_aiter_norm
:
str
,
monkeypatch
):
use_rocm_aiter
:
str
,
use_rocm_aiter_norm
:
str
,
monkeypatch
):
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
use_rocm_aiter
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
use_rocm_aiter
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
use_rocm_aiter_norm
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
use_rocm_aiter_norm
)
rms_norm_func
=
dispatch_
cuda
_rmsnorm_func
(
add_residual
)
rms_norm_func
=
dispatch_
rocm
_rmsnorm_func
(
add_residual
,
dtype
)
if
not
add_residual
:
should_use_rocm_aiter
=
current_platform
.
is_rocm
()
and
int
(
use_rocm_aiter
)
\
if
current_platform
.
is_rocm
()
and
int
(
use_rocm_aiter
)
and
int
(
and
int
(
use_rocm_aiter_norm
)
and
dtype
in
RMS_NORM_SUPPORTED_DTYPES
use_rocm_aiter_norm
):
assert
rms_norm_func
==
rocm_aiter_rms_norm
if
add_residual
and
should_use_rocm_aiter
:
assert
rms_norm_func
==
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm2d_fwd_with_add
elif
should_use_rocm_aiter
:
assert
rms_norm_func
==
torch
.
ops
.
vllm
.
rocm_aiter_rms_norm
elif
add_residual
:
assert
rms_norm_func
==
fused_add_rms_norm
else
:
else
:
assert
rms_norm_func
==
rms_norm
assert
rms_norm_func
==
rms_norm
elif
current_platform
.
is_rocm
()
and
int
(
use_rocm_aiter
)
and
int
(
use_rocm_aiter_norm
):
assert
rms_norm_func
==
rocm_aiter_fused_add_rms_norm
else
:
assert
rms_norm_func
==
fused_add_rms_norm
vllm/model_executor/layers/layernorm.py
View file @
7c195d43
...
@@ -9,11 +9,11 @@ import torch.nn as nn
...
@@ -9,11 +9,11 @@ import torch.nn as nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
def
is_rocm_aiter_rmsnorm_enabled
()
->
bool
:
def
is_rocm_aiter_rmsnorm_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
return
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
and
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
and
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER
...
@@ -43,7 +43,7 @@ def fused_add_rms_norm(
...
@@ -43,7 +43,7 @@ def fused_add_rms_norm(
return
x
,
residual
return
x
,
residual
def
rocm_aiter_rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rocm_aiter_rms_norm
_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
variance_epsilon
:
float
)
->
torch
.
Tensor
:
import
aiter
as
rocm_aiter
import
aiter
as
rocm_aiter
if
x
.
dim
()
>
2
:
if
x
.
dim
()
>
2
:
...
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
...
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return
rocm_aiter
.
rms_norm
(
x
,
weight
,
variance_epsilon
)
return
rocm_aiter
.
rms_norm
(
x
,
weight
,
variance_epsilon
)
def
rocm_aiter_
fused_add_rms_norm
(
def
rocm_aiter_
rmsnorm2d_fwd_with_add_impl
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
...
@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
return
output
,
residual_out
return
output
,
residual_out
def
dispatch_cuda_rmsnorm_func
(
add_residual
:
bool
):
def
rocm_aiter_rms_norm_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
if
add_residual
:
variance_epsilon
:
float
)
->
torch
.
Tensor
:
if
is_rocm_aiter_rmsnorm_enabled
():
return
torch
.
empty_like
(
x
)
return
rocm_aiter_fused_add_rms_norm
return
fused_add_rms_norm
def
rocm_aiter_rmsnorm2d_fwd_with_add_fake
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
if
current_platform
.
is_rocm
():
direct_register_custom_op
(
op_name
=
"rocm_aiter_rms_norm"
,
op_func
=
rocm_aiter_rms_norm_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_rms_norm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm2d_fwd_with_add"
,
op_func
=
rocm_aiter_rmsnorm2d_fwd_with_add_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_rmsnorm2d_fwd_with_add_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
def
dispatch_rocm_rmsnorm_func
(
with_fused_add
:
bool
,
dtype
:
torch
.
dtype
):
use_aiter
=
is_rocm_aiter_rmsnorm_enabled
()
and
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
if
use_aiter
and
with_fused_add
:
return
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm2d_fwd_with_add
if
use_aiter
:
return
torch
.
ops
.
vllm
.
rocm_aiter_rms_norm
if
is_rocm_aiter_rmsnorm_enabled
():
# fall back to CUDA implementation
return
rocm_aiter_rms_norm
if
with_fused_add
:
return
fused_add_rms_norm
return
rms_norm
return
rms_norm
...
@@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
...
@@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
self
.
weight
=
torch
.
ones
(
hidden_size
)
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
self
.
has_weight
:
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
weight_dtype
=
self
.
weight
.
data
.
dtype
if
current_platform
.
is_rocm
():
self
.
rocm_norm_func
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
False
,
dtype
=
weight_dtype
)
self
.
rocm_norm_func_with_add
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
True
,
dtype
=
weight_dtype
)
def
forward_native
(
def
forward_native
(
self
,
self
,
...
@@ -162,13 +203,27 @@ class RMSNorm(CustomOp):
...
@@ -162,13 +203,27 @@ class RMSNorm(CustomOp):
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
add_residual
=
residual
is
not
None
norm_func
=
dispatch_cuda_rmsnorm_func
(
add_residual
)
if
add_residual
:
return
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
else
:
return
rms_norm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_hip
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
if
add_residual
:
if
add_residual
:
return
norm_func
(
x
,
residual
,
self
.
weight
.
data
,
return
self
.
rocm_norm_func_with_add
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
self
.
variance_epsilon
)
else
:
else
:
return
norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
self
.
rocm_norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_xpu
(
def
forward_xpu
(
self
,
self
,
...
...
vllm/platforms/rocm.py
View file @
7c195d43
...
@@ -322,23 +322,35 @@ class RocmPlatform(Platform):
...
@@ -322,23 +322,35 @@ class RocmPlatform(Platform):
@
classmethod
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
from
vllm.config.compilation
import
CUDAGraphMode
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
compilation_config
=
vllm_config
.
compilation_config
parallel_config
=
vllm_config
.
parallel_config
is_eager_execution
=
compilation_config
==
CUDAGraphMode
.
NONE
use_v1
=
envs
.
VLLM_USE_V1
use_aiter_rms_norm
=
envs
.
VLLM_ROCM_USE_AITER
and
\
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
if
cache_config
and
cache_config
.
block_size
is
None
:
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
cache_config
.
block_size
=
16
parallel_config
=
vllm_config
.
parallel_config
if
parallel_config
.
worker_cls
==
"auto"
:
if
parallel_config
.
worker_cls
==
"auto"
:
if
vllm_config
.
speculative_config
:
if
vllm_config
.
speculative_config
:
if
not
envs
.
VLLM_USE_V
1
:
if
not
use_v
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Speculative decoding is not supported on vLLM V0."
)
"Speculative decoding is not supported on vLLM V0."
)
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
else
:
else
:
if
envs
.
VLLM_USE_V
1
:
if
use_v
1
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.gpu_worker.Worker"
"vllm.v1.worker.gpu_worker.Worker"
else
:
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
# Aiter rms norm performs best when CUDA Graph capture is enabled.
if
use_v1
and
use_aiter_rms_norm
and
not
is_eager_execution
:
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
@
classmethod
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment