Unverified Commit ec870fba authored by TJian's avatar TJian Committed by GitHub
Browse files

[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959)


Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent df143026
...@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ...@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb" ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git" ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base FROM ${BASE_IMAGE} AS base
...@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ ...@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl pip install /install/*.whl
ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
ARG BASE_IMAGE ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
ARG LEGACY_HIPBLASLT_OPTION ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH ARG RCCL_BRANCH
ARG RCCL_REPO ARG RCCL_REPO
...@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ ...@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
...@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul, from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul) SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.platforms import current_platform
# Registered subclass for test # Registered subclass for test
...@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str): ...@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(","))) custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
use_rocm_aiter_norm: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if not add_residual:
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_rms_norm
else:
assert rms_norm_func == rms_norm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
else:
assert rms_norm_func == fused_add_rms_norm
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
Run `pytest tests/models/test_models.py`. Run `pytest tests/models/test_models.py`.
""" """
import pytest import pytest
import torch
from vllm.platforms import current_platform
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
...@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close ...@@ -13,7 +17,21 @@ from ...utils import check_logprobs_close
# https://github.com/vllm-project/vllm/issues/14524 # https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
# This list contains the model that are using AITER kernel.
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
AITER_MODEL_LIST = [
"meta-llama/Llama-3.2-1B-Instruct",
"openbmb/MiniCPM3-4B",
"Qwen/Qwen-7B",
"Qwen/Qwen2.5-0.5B-Instruct",
"ehristoforu/Falcon3-MoE-2x7B-Insruct",
]
# @maybe_test_rocm_aiter
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
...@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] ...@@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models( @pytest.mark.parametrize(
hf_runner, "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
vllm_runner, def test_models(hf_runner, vllm_runner, example_prompts, model: str,
example_prompts, dtype: str, max_tokens: int, num_logprobs: int,
model: str, use_rocm_aiter: bool, monkeypatch) -> None:
dtype: str,
max_tokens: int,
num_logprobs: int,
monkeypatch,
) -> None:
if model in REQUIRES_V0: if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0") monkeypatch.setenv("VLLM_USE_V1", "0")
if use_rocm_aiter and (model in AITER_MODEL_LIST):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
elif use_rocm_aiter and model not in AITER_MODEL_LIST:
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
# needed as all the models will be calling AITER kernels
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
if model.startswith("THUDM/chatglm3"): if model.startswith("THUDM/chatglm3"):
hf_model.model.get_output_embeddings = lambda: \ hf_model.model.get_output_embeddings = lambda: \
...@@ -100,3 +123,10 @@ def test_models( ...@@ -100,3 +123,10 @@ def test_models(
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
) )
if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next
# unit tests. On ROCm, when using AITER
# the memory might not be deallocated completely
# before running the next test case
torch.cuda.synchronize()
...@@ -75,6 +75,8 @@ if TYPE_CHECKING: ...@@ -75,6 +75,8 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
...@@ -528,6 +530,17 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -528,6 +530,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V1": "VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm # Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING": "VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
......
...@@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union ...@@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER
def rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
ops.fused_add_rms_norm(
x,
residual,
weight,
variance_epsilon,
)
return x, residual
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
# Assuming the correct signature for rmsnorm2d_fwd_with_add
rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
x, # input
residual, # residual input
residual, # residual output
weight,
variance_epsilon,
)
return x, residual
def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
return rms_norm
@CustomOp.register("rms_norm") @CustomOp.register("rms_norm")
...@@ -81,24 +151,14 @@ class RMSNorm(CustomOp): ...@@ -81,24 +151,14 @@ class RMSNorm(CustomOp):
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
from vllm import _custom_ops as ops add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if residual is not None: if add_residual:
ops.fused_add_rms_norm( return norm_func(x, residual, self.weight.data,
x, self.variance_epsilon)
residual, else:
self.weight.data, return norm_func(x, self.weight.data, self.variance_epsilon)
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
def forward_hpu( def forward_hpu(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment