Unverified commit 8853a50a, authored by rasmith, committed by GitHub
Browse files

[CI][BugFix][AMD][FP8] Fix test_rms_norm so it runs correctly on ROCm (#32372)


Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
parent c5891b54
...@@ -14,9 +14,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -14,9 +14,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
) )
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
# Avoid combinatorial explosion with full Cartesian product # Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [ NUM_TOKENS_HIDDEN_SIZES = [
...@@ -61,14 +62,14 @@ def ref_dynamic_per_token_or_block_quant( ...@@ -61,14 +62,14 @@ def ref_dynamic_per_token_or_block_quant(
group_size: list[int] | None, group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None: if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn assert quant_dtype == current_platform.fp8_dtype()
# Norm # Norm
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant # Quant
if group_size is not None: if group_size is not None:
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == current_platform.fp8_dtype():
torch_out, scales = per_token_group_quant_fp8( torch_out, scales = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False torch_out, group_size=group_size[1], use_ue8m0=False
) )
...@@ -78,7 +79,7 @@ def ref_dynamic_per_token_or_block_quant( ...@@ -78,7 +79,7 @@ def ref_dynamic_per_token_or_block_quant(
torch_out, group_size=group_size[1] torch_out, group_size=group_size[1]
) )
else: else:
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == current_platform.fp8_dtype():
torch_out, scales = ops.scaled_fp8_quant( torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
) )
...@@ -162,6 +163,7 @@ def test_rms_norm( ...@@ -162,6 +163,7 @@ def test_rms_norm(
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device)
if group_size is not None and hidden_size % group_size[1] != 0: if group_size is not None and hidden_size % group_size[1] != 0:
# skip # skip
...@@ -171,7 +173,7 @@ def test_rms_norm( ...@@ -171,7 +173,7 @@ def test_rms_norm(
# blockwise baseline doesn't support scale_ub # blockwise baseline doesn't support scale_ub
return return
if has_scale_ub and quant_dtype != torch.float8_e4m3fn: if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
# skip # skip
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment