Unverified commit 2463f00f, authored by rasmith, committed by GitHub
Browse files

[AMD][CI][BugFix] Override normalize_e4m3fn_to_e4m3fnuz for fnuz machines in...


[AMD][CI][BugFix] Override normalize_e4m3fn_to_e4m3fnuz for fnuz machines in test_moe_layer_no_parallel (#40550)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent f946659f
...@@ -17,6 +17,7 @@ from typing import get_args ...@@ -17,6 +17,7 @@ from typing import get_args
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.quantization.utils.w8a8_utils
from tests.kernels.moe.modular_kernel_tools.parallel_utils import ( from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
ProcessGroupInfo, ProcessGroupInfo,
_set_vllm_config, _set_vllm_config,
...@@ -144,6 +145,24 @@ EPLB_SUPPORTED_QUANTS: list[str | None] = [None, "fp8"] ...@@ -144,6 +145,24 @@ EPLB_SUPPORTED_QUANTS: list[str | None] = [None, "fp8"]
EPLB_SUPPORTED_BACKENDS: list[str] = ["allgather_reducescatter"] EPLB_SUPPORTED_BACKENDS: list[str] = ["allgather_reducescatter"]
def mock_normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
):
return weight, weight_scale, input_scale
def override_normalize_e4m3fn_to_e4m3fnuz():
    """Replace ``w8a8_utils.normalize_e4m3fn_to_e4m3fnuz`` with the mock.

    Needed since weights will already be in e4m3fnuz format on platforms
    that use the fnuz fp8 format, and the real
    ``normalize_e4m3fn_to_e4m3fnuz()`` function is not being tested here.

    NOTE: The weights are quantized by ``moe_quantize_weights_2d`` in
    ``_quantize_fp8_halves``.
    NOTE: Not able to use ``monkeypatch`` because of the spawned parallel
    workers, so the module attribute is rebound directly. Nothing in this
    function restores the original attribute afterwards.
    """
    # Alias the module so the rebind fits on one line.
    w8a8_utils = vllm.model_executor.layers.quantization.utils.w8a8_utils
    w8a8_utils.normalize_e4m3fn_to_e4m3fnuz = mock_normalize_e4m3fn_to_e4m3fnuz
def maybe_roundup_layer_hidden_size( def maybe_roundup_layer_hidden_size(
hidden_size: int, hidden_size: int,
act_dtype: torch.dtype, act_dtype: torch.dtype,
...@@ -1471,6 +1490,11 @@ def test_moe_layer_no_parallel( ...@@ -1471,6 +1490,11 @@ def test_moe_layer_no_parallel(
if os.environ.get("VLLM_LOGGING_LEVEL") is None: if os.environ.get("VLLM_LOGGING_LEVEL") is None:
monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR") monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR")
# Needed since weights will already be in e4m3fnuz format and the
# normalize_e4m3fn_to_e4m3fnuz() function is not being tested here.
if current_platform.is_fp8_fnuz():
override_normalize_e4m3fn_to_e4m3fnuz()
test_config = MoETestConfig( test_config = MoETestConfig(
m, m,
n, n,
...@@ -1546,6 +1570,9 @@ def _parallel_worker( ...@@ -1546,6 +1570,9 @@ def _parallel_worker(
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
if current_platform.is_fp8_fnuz():
override_normalize_e4m3fn_to_e4m3fnuz()
for test_config in test_configs: for test_config in test_configs:
cc = vllm_config.compilation_config cc = vllm_config.compilation_config
if "from_forward_context" in cc.static_forward_context: if "from_forward_context" in cc.static_forward_context:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment