Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec870fba
Unverified
Commit
ec870fba
authored
Mar 22, 2025
by
TJian
Committed by
GitHub
Mar 21, 2025
Browse files
[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
df143026
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
173 additions
and
29 deletions
+173
-29
Dockerfile.rocm_base
Dockerfile.rocm_base
+15
-1
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+28
-1
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+40
-10
vllm/envs.py
vllm/envs.py
+13
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+77
-17
No files found.
Dockerfile.rocm_base
View file @
ec870fba
...
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
...
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base
FROM ${BASE_IMAGE} AS base
...
@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
...
@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
pip install /install/*.whl
ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
ARG BASE_IMAGE
ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_BRANCH
ARG RCCL_REPO
ARG RCCL_REPO
...
@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
...
@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
tests/model_executor/test_enabled_custom_ops.py
View file @
ec870fba
...
@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
...
@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
ReLUSquaredActivation
,
ReLUSquaredActivation
,
SiluAndMul
)
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
(
RMSNorm
,
dispatch_cuda_rmsnorm_func
,
fused_add_rms_norm
,
rms_norm
,
rocm_aiter_fused_add_rms_norm
,
rocm_aiter_rms_norm
)
from
vllm.platforms
import
current_platform
# Registered subclass for test
# Registered subclass for test
...
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
...
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
custom_ops
=
env
.
split
(
","
)))
custom_ops
=
env
.
split
(
","
)))
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
RMSNorm
(
1024
).
enabled
()
RMSNorm
(
1024
).
enabled
()
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
                           use_rocm_aiter_norm: str, monkeypatch):
    """Check which RMSNorm callable ``dispatch_cuda_rmsnorm_func`` selects.

    The two AITER environment variables are set to every "0"/"1"
    combination; the AITER variant is expected only when the platform is
    ROCm and both flags are "1", otherwise the default variant is expected.

    Args:
        add_residual: Whether the fused residual-add variant is requested.
        use_rocm_aiter: Value written to ``VLLM_ROCM_USE_AITER``.
        use_rocm_aiter_norm: Value written to ``VLLM_ROCM_USE_AITER_RMSNORM``.
        monkeypatch: pytest fixture used to set the environment variables.
    """
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
    selected = dispatch_cuda_rmsnorm_func(add_residual)

    # AITER is expected only on ROCm with both switches turned on.
    aiter_expected = bool(current_platform.is_rocm() and int(use_rocm_aiter)
                          and int(use_rocm_aiter_norm))

    if add_residual:
        expected = (rocm_aiter_fused_add_rms_norm
                    if aiter_expected else fused_add_rms_norm)
    else:
        expected = rocm_aiter_rms_norm if aiter_expected else rms_norm

    assert selected == expected
tests/models/decoder_only/language/test_models.py
View file @
ec870fba
...
@@ -3,7 +3,11 @@
...
@@ -3,7 +3,11 @@
Run `pytest tests/models/test_models.py`.
Run `pytest tests/models/test_models.py`.
"""
"""
import
pytest
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
...
@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close
...
@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close
# https://github.com/vllm-project/vllm/issues/14524
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0
=
[
"microsoft/phi-2"
,
"stabilityai/stablelm-3b-4e1t"
]
REQUIRES_V0
=
[
"microsoft/phi-2"
,
"stabilityai/stablelm-3b-4e1t"
]
# This list contains the models that exercise AITER kernels.
# Models that are not in this list are skipped in the AITER tests.
# Once more AITER kernels are added, this list will no longer be
# needed, because every model will call AITER kernels
# in at least some of its operators.
AITER_MODEL_LIST
=
[
"meta-llama/Llama-3.2-1B-Instruct"
,
"openbmb/MiniCPM3-4B"
,
"Qwen/Qwen-7B"
,
"Qwen/Qwen2.5-0.5B-Instruct"
,
"ehristoforu/Falcon3-MoE-2x7B-Insruct"
,
]
# @maybe_test_rocm_aiter
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model"
,
"model"
,
[
[
...
@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
...
@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
@
pytest
.
mark
.
parametrize
(
hf_runner
,
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
])
vllm_runner
,
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
example_prompts
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
model
:
str
,
use_rocm_aiter
:
bool
,
monkeypatch
)
->
None
:
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
monkeypatch
,
)
->
None
:
if
model
in
REQUIRES_V0
:
if
model
in
REQUIRES_V0
:
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
if
use_rocm_aiter
and
(
model
in
AITER_MODEL_LIST
):
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
elif
use_rocm_aiter
and
model
not
in
AITER_MODEL_LIST
:
# Skip models that do not exercise AITER kernels.
# Once more AITER kernels are added, this list will no longer be
# needed, because every model will call AITER kernels
# in at least some of its operators.
pytest
.
skip
(
f
"Skipping '
{
model
}
' model test with AITER kernel."
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
if
model
.
startswith
(
"THUDM/chatglm3"
):
if
model
.
startswith
(
"THUDM/chatglm3"
):
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
...
@@ -100,3 +123,10 @@ def test_models(
...
@@ -100,3 +123,10 @@ def test_models(
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
)
)
if
use_rocm_aiter
:
# This ensures that the vLLM engine
# has deallocated its memory before the next
# unit test runs. On ROCm, when using AITER,
# the memory might not be completely deallocated
# before the next test case starts.
torch
.
cuda
.
synchronize
()
vllm/envs.py
View file @
ec870fba
...
@@ -75,6 +75,8 @@ if TYPE_CHECKING:
...
@@ -75,6 +75,8 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
True
VLLM_USE_V1
:
bool
=
True
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
...
@@ -528,6 +530,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -528,6 +530,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V1"
:
"VLLM_USE_V1"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"1"
))),
# AITER ops are disabled unless explicitly enabled.
# Acts as the parent switch that gates all the other AITER operations.
"VLLM_ROCM_USE_AITER"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Pad the fp8 weights to 256 bytes for ROCm
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING"
:
"VLLM_ROCM_FP8_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
...
...
vllm/model_executor/layers/layernorm.py
View file @
ec870fba
...
@@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union
...
@@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm path should be used.

    The result is truthy only when the current platform is ROCm and both
    ``envs.VLLM_ROCM_USE_AITER_RMSNORM`` and ``envs.VLLM_ROCM_USE_AITER``
    are truthy. The platform check is evaluated first and short-circuits
    the environment lookups.
    """
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
            and envs.VLLM_ROCM_USE_AITER)
def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Run the vLLM custom-op RMSNorm and return a freshly allocated result.

    Args:
        x: Input tensor; the output is allocated with the same shape/dtype
            via ``torch.empty_like``.
        weight: Weight tensor forwarded to the kernel.
        variance_epsilon: Epsilon value forwarded to the kernel.

    Returns:
        The tensor passed to ``ops.rms_norm`` as its output argument.
    """
    # Imported lazily so the compiled ops are only loaded when needed.
    from vllm import _custom_ops as custom_ops

    normed = torch.empty_like(x)
    custom_ops.rms_norm(normed, x, weight, variance_epsilon)
    return normed
def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the vLLM custom-op fused residual-add + RMSNorm.

    Args:
        x: Input tensor handed to the kernel.
        residual: Residual tensor handed to the kernel.
        weight: Weight tensor forwarded to the kernel.
        variance_epsilon: Epsilon value forwarded to the kernel.

    Returns:
        The same ``x`` and ``residual`` objects that were passed in
        (presumably updated in place by the kernel -- confirm against
        ``vllm._custom_ops.fused_add_rms_norm``).
    """
    # Imported lazily so the compiled ops are only loaded when needed.
    from vllm import _custom_ops as custom_ops

    custom_ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """Run RMSNorm through the ROCm AITER library.

    Args:
        x: Input tensor.
        weight: Weight tensor forwarded to AITER.
        variance_epsilon: Epsilon value forwarded to AITER.

    Returns:
        Whatever ``aiter.rms_norm`` returns for the given arguments.
    """
    # AITER is an optional ROCm-only dependency, so import it on demand.
    import aiter

    return aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run fused residual-add + RMSNorm through the ROCm AITER library.

    ``x`` and ``residual`` are each passed twice to
    ``aiter.rmsnorm2d_fwd_with_add``, once in the output slot and once in
    the input slot.

    Args:
        x: Tensor used as both output and input of the kernel call.
        residual: Tensor used as both residual input and residual output.
        weight: Weight tensor forwarded to AITER.
        variance_epsilon: Epsilon value forwarded to AITER.

    Returns:
        The same ``x`` and ``residual`` objects that were passed in.
    """
    # AITER is an optional ROCm-only dependency, so import it on demand.
    import aiter

    # NOTE(review): the argument order below assumes the signature of
    # rmsnorm2d_fwd_with_add is (out, input, residual_in, residual_out,
    # weight, epsilon) -- confirm against the pinned AITER version.
    aiter.rmsnorm2d_fwd_with_add(
        x,  # output
        x,  # input
        residual,  # residual input
        residual,  # residual output
        weight,
        variance_epsilon,
    )
    return x, residual
def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Pick the RMSNorm callable to use on the CUDA/ROCm forward path.

    Args:
        add_residual: When truthy, a fused residual-add variant is chosen;
            otherwise a plain RMSNorm variant is chosen.

    Returns:
        ``rocm_aiter_fused_add_rms_norm`` or ``fused_add_rms_norm`` when
        ``add_residual`` is truthy, else ``rocm_aiter_rms_norm`` or
        ``rms_norm``. The AITER variant is returned whenever
        ``is_rocm_aiter_rmsnorm_enabled()`` is truthy.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled()

    if add_residual:
        return (rocm_aiter_fused_add_rms_norm
                if use_aiter else fused_add_rms_norm)
    return rocm_aiter_rms_norm if use_aiter else rms_norm
@
CustomOp
.
register
(
"rms_norm"
)
@
CustomOp
.
register
(
"rms_norm"
)
...
@@ -81,24 +151,14 @@ class RMSNorm(CustomOp):
...
@@ -81,24 +151,14 @@ class RMSNorm(CustomOp):
if
self
.
variance_size_override
is
not
None
:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
from
vllm
import
_custom_ops
as
ops
add_residual
=
residual
is
not
None
norm_func
=
dispatch_cuda_rmsnorm_func
(
add_residual
)
if
residual
is
not
None
:
if
add_residual
:
ops
.
fused_add_rms_norm
(
return
norm_func
(
x
,
residual
,
self
.
weight
.
data
,
x
,
self
.
variance_epsilon
)
residual
,
else
:
self
.
weight
.
data
,
return
norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
self
.
variance_epsilon
,
)
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
ops
.
rms_norm
(
out
,
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
return
out
def
forward_hpu
(
def
forward_hpu
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment