Unverified Commit 1e0f917b authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][CI] Fix logprob divergence for TitanML/tiny-mixtral under AITER rms_norm (#36101)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent c174d54f
...@@ -126,6 +126,10 @@ def test_models( ...@@ -126,6 +126,10 @@ def test_models(
if use_rocm_aiter and (model in AITER_MODEL_LIST): if use_rocm_aiter and (model in AITER_MODEL_LIST):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if model == "TitanML/tiny-mixtral":
# Untrained model: near-uniform logits make argmax sensitive to
# AITER's bfloat16 rounding error in plain rms_norm.
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "0")
elif use_rocm_aiter and model not in AITER_MODEL_LIST: elif use_rocm_aiter and model not in AITER_MODEL_LIST:
# Skip model that are not using AITER tests. # Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be # When more AITER kernels are added, this list will not be
......
...@@ -1602,6 +1602,41 @@ def override_cutlass_fp8_supported(value: bool): ...@@ -1602,6 +1602,41 @@ def override_cutlass_fp8_supported(value: bool):
yield yield
def disable_aiter_plain_rmsnorm(monkeypatch) -> None:
"""Patch dispatch_rocm_rmsnorm_func so the plain (non-fused) rms_norm path
always uses the native float32 kernel for the duration of a test.
The fused path (rms_norm2d_with_add, selected when with_fused_add=True) is
left on AITER -- only the plain path is redirected to native.
AITER's plain rms_norm accumulates variance in bfloat16 (~1 ULP/call),
which drifts the KV cache over many decode steps. This drift is irrelevant
for a trained model (rank-1/rank-2 gap ~1-3 nats >> 1 ULP), but breaks
logprob comparison tests with randomly-initialised models like
TitanML/tiny-mixtral whose rank-1/rank-2 gap is only O(1/sqrt(V)) ~0.006
nats -- smaller than the accumulated per-step error.
"""
import torch
import vllm.model_executor.layers.layernorm as _ln_mod
from vllm.model_executor.layers.layernorm import rms_norm as _native
_orig = _ln_mod.dispatch_rocm_rmsnorm_func
def _native_plain(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
if (
use_aiter
and not with_fused_add
and dtype in (torch.float16, torch.bfloat16)
):
return _native
return _orig(with_fused_add, dtype, use_aiter)
monkeypatch.setattr(_ln_mod, "dispatch_rocm_rmsnorm_func", _native_plain)
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
""" """
Generate prompts which a bunch of assignments, Generate prompts which a bunch of assignments,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment