Unverified commit e54894fc, authored by Andreas Karatzas and committed by GitHub
Browse files

[ROCm][CI] Fix TestSiluMulGroupFp8QuantModel after W8A8 block linear refactor (#39799)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent bc2ae5a3
......@@ -10,7 +10,7 @@ import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from tests.utils import TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.act_quant_fusion import (
FUSED_OPS,
......@@ -157,12 +157,13 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
activation_quant_key=self.act_quant_key,
input_dtype=dtype,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
if not current_platform.is_fp8_fnuz():
kernel = self.w8a8_block_fp8_linear.kernel
orig_quant = kernel.quant_fp8
kernel.quant_fp8 = lambda *a, use_triton=False, **kw: orig_quant(
*a, use_triton=True, **kw
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
......@@ -174,6 +175,9 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def ops_in_model_before(self):
    """Return the two ops this model reports for the "before" check.

    The first entry is the activation op: ``SILU_MUL_OP`` when
    ``self.enable_silu_mul_custom_op`` is truthy, otherwise
    ``torch.ops.aten.mul``.

    The second entry is the group-quant op: the op returned by
    ``rocm_aiter_ops.get_group_quant_op()`` when the current platform
    reports FP8 fnuz, otherwise the Triton per-token group FP8 quant op.

    NOTE(review): presumably these are the ops the test backend expects to
    see in the graph before the fusion pass runs -- confirm against caller.
    """
    # Pick the activation op first so evaluation order matches the
    # original list literal.
    if self.enable_silu_mul_custom_op:
        activation_op = SILU_MUL_OP
    else:
        activation_op = torch.ops.aten.mul

    # Only query the AITER op on fnuz platforms; the other branch never
    # touches rocm_aiter_ops.
    if current_platform.is_fp8_fnuz():
        group_quant_op = rocm_aiter_ops.get_group_quant_op()
    else:
        group_quant_op = torch.ops.vllm.triton_per_token_group_quant_fp8.default

    return [activation_op, group_quant_op]
def ops_in_model_after(self):
......@@ -324,7 +328,6 @@ def test_fusion_silu_and_mul_quant(
with set_current_vllm_config(config), monkeypatch.context() as m:
fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
......@@ -352,10 +355,16 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3
elif isinstance(model, TestSiluMulNvfp4QuantModel):
atol, rtol = 1e-1, 1e-1
elif isinstance(
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
):
elif isinstance(model, TestSiluMulGroupFp8QuantModel):
atol, rtol = 5e-2, 5e-2
elif isinstance(model, TestSiluMulBlockQuantModel):
if current_platform.is_rocm():
atol, rtol = 1e-3, 1e-3
else:
# CUDA fused kernel computes silu*mul in fp32 while the reference
# goes through bf16/fp16 storage, so group maxima (and thus scales)
# can shift by one FP8-e4m3 code (~1/8 relative step).
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment