Commit 16688b26 (unverified), authored by Wentao Ye, committed by GitHub
Browse files

[Perf] Optimize batch invariant with fused rms norm, 2.1% E2E latency improvement (#40413)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 6fbec8ed
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from utils import skip_unsupported from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm, fused_add_rms_norm
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type DEVICE_TYPE = current_platform.device_type
...@@ -71,6 +71,93 @@ def test_rms_norm_batch_invariant_vs_standard( ...@@ -71,6 +71,93 @@ def test_rms_norm_batch_invariant_vs_standard(
) )
@skip_unsupported
@pytest.mark.parametrize("hidden_size", [512, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6])
def test_fused_add_rms_norm_batch_invariant_residual_path(
    hidden_size: int,
    dtype: torch.dtype,
    eps: float,
):
    """
    Exercise the fused residual-add + RMSNorm helper for batch invariance.

    A single row is normalized once on its own and once as row 0 of a
    four-row batch. Both the normalized output and the updated residual for
    that row must match bit-for-bit across the two calls, and the normalized
    output must stay within a dtype-dependent tolerance of the Triton
    batch-invariant RMSNorm applied to ``x + residual``.
    """
    device = torch.device(DEVICE_TYPE)
    torch.manual_seed(42)

    def rand(*shape: int) -> torch.Tensor:
        return torch.randn(*shape, dtype=dtype, device=device)

    # The draw order below is deliberate: it fixes which seeded values end
    # up in each tensor.
    lone_x = rand(1, hidden_size)
    lone_residual = rand(1, hidden_size)
    weight = rand(hidden_size)
    batched_x = torch.cat([lone_x, rand(3, hidden_size)], dim=0)
    batched_residual = torch.cat([lone_residual, rand(3, hidden_size)], dim=0)

    # Clones are handed to the helper so the reference sum computed afterwards
    # is built from the original inputs (the helper presumably writes in
    # place -- confirm against the kernel).
    normed_lone, new_residual_lone = fused_add_rms_norm(
        lone_x.clone(), lone_residual.clone(), weight, eps
    )
    normed_batched, new_residual_batched = fused_add_rms_norm(
        batched_x.clone(), batched_residual.clone(), weight, eps
    )

    expected_sum = lone_x + lone_residual
    reference = triton_rms_norm(expected_sum, weight, eps=eps)

    # Bit-exact checks: (actual, expected, failure message).
    exact_checks = (
        (
            new_residual_lone,
            expected_sum,
            "Residual output should equal x + residual exactly",
        ),
        (
            new_residual_batched[:1],
            expected_sum,
            "Residual output should be batch invariant",
        ),
        (
            normed_lone,
            normed_batched[:1],
            "Fused add RMSNorm output should be batch invariant",
        ),
    )
    for actual, expected, message in exact_checks:
        torch.testing.assert_close(
            actual, expected, rtol=0.0, atol=0.0, msg=message
        )

    # bfloat16 gets a looser bound than float16 when comparing against the
    # separately computed reference.
    tolerance = 1e-1 if dtype == torch.bfloat16 else 1e-2
    torch.testing.assert_close(
        normed_lone,
        reference,
        rtol=tolerance,
        atol=tolerance,
        msg="Fused add RMSNorm output should stay numerically close to the "
        "batch-invariant RMSNorm reference",
    )
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("batch_size", [1, 16, 128]) @pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512]) @pytest.mark.parametrize("seq_len", [1, 32, 512])
......
...@@ -420,6 +420,7 @@ def rms_norm( ...@@ -420,6 +420,7 @@ def rms_norm(
def fused_add_rms_norm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    """Dispatch fused residual-add + RMSNorm to the ``_C`` custom op.

    Args:
        input: Tensor forwarded as the op's first argument.
        residual: Tensor forwarded as the op's second argument.
        weight: RMSNorm weight tensor forwarded to the op.
        epsilon: Epsilon value forwarded to the op.

    Returns:
        None. The op's return value is not used; results are presumably
        written into ``input`` and ``residual`` in place -- confirm against
        the kernel implementation.
    """
    # Note: this func is batch invariant
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
......
...@@ -61,10 +61,6 @@ def fused_add_rms_norm( ...@@ -61,10 +61,6 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
residual, residual,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment