"docker/Dockerfile.tpu" did not exist on "6a512a00dfa306762c2878bffc3a5664a758d105"
test_fusion.py 5.25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
import pytest
import torch

import vllm.envs as envs
8
import vllm.plugins
9
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
10
                                     FusionPass)
11
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
13
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)
14
from vllm.model_executor.layers.layernorm import RMSNorm
15
16
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, QuantKey, ScaleDesc)
17
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
18
    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
19
from vllm.platforms import current_platform
20

21
from ..utils import override_cutlass_fp8_supported
22
23
from .backend import TestBackend

# FP8 element dtype reported by the current platform; used for the test
# weights and for the quant key the fusion pass is expected to match.
FP8_DTYPE = current_platform.fp8_dtype()


class TestModel(torch.nn.Module):
    """Toy model chaining RMSNorm and FP8 linear layers for fusion tests.

    ``forward`` runs ``norm[0]`` (no residual) -> FP8 linear ->
    ``norm[1]`` (with residual) -> FP8 linear -> ``norm[2]`` (with residual).
    Each RMSNorm feeding an FP8 linear is a candidate for ``FusionPass`` to
    rewrite into a fused RMSNorm + quant op.

    Args:
        hidden_size: Width of the activations and of the square weights.
        eps: Epsilon passed to every ``RMSNorm``.
        static: If True, use static per-tensor activation scales; otherwise
            per-token group shape with ``input_scale=None`` (dynamic).
        cuda_force_torch: If True, ``Fp8LinearOp`` is constructed while
            CUTLASS FP8 support is overridden to False, so the torch code
            path is exercised even on CUTLASS-capable platforms.
    """

    def __init__(self, hidden_size: int, eps: float, static: bool,
                 cuda_force_torch: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        # NOTE(review): plain lists, so these modules/tensors are not
        # registered as submodules/buffers; consider nn.ModuleList only after
        # confirming it does not change the traced graph.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]

        group_shape = (GroupShape.PER_TENSOR
                       if static else GroupShape.PER_TOKEN)
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        # Quant key the fusion pass should match: FP8, fp32 scale, symmetric.
        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)

        # Static quantization uses fixed activation scales; dynamic
        # quantization passes ``input_scale=None`` to ``apply``.
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        else:
            self.scale = [None, None]

        # Square FP8 weights, stored as transposed views.
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

        # Override only while constructing the op, so the choice of
        # CUTLASS vs. torch path is baked into ``self.fp8_linear``.
        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=static,
                act_quant_group_shape=group_shape,
            )

    def forward(self, x):
        """Apply norm -> fp8 linear -> norm(+resid) -> fp8 linear -> norm."""
        resid = torch.sqrt(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(y,
                                   self.w[0],
                                   self.wscale[0],
                                   input_scale=self.scale[0])
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(y2,
                                   self.w[1],
                                   self.wscale[1],
                                   input_scale=self.scale[1])
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3

    def ops_in_model_before(self):
        """Return the ops expected before fusion: the standalone quant op."""
        return [QUANT_OPS[self.key]]

    def ops_in_model_after(self):
        """Return the fused RMSNorm + quant ops expected after fusion.

        Both ``FusedRMSQuantKey(self.key, False)`` and ``(self.key, True)``
        are expected; the bool presumably selects the residual-add variant
        (TODO confirm against ``FusedRMSQuantKey``).
        """
        return [
            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
            FUSED_OPS[FusedRMSQuantKey(self.key, True)]
        ]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
                         [True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                              cuda_force_torch):
    """Check that RMSNorm + FP8 quant fusion keeps outputs and swaps ops.

    Runs ``TestModel`` eagerly and via ``torch.compile`` with a backend that
    applies the no-op elimination and fusion passes, compares the results
    within dtype/quant-mode dependent tolerances, then checks the graph held
    the quant op before the passes and the fused ops after them.
    """
    # Save the process-wide defaults so later tests are not affected.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    try:
        torch.manual_seed(1)
        maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

        vllm_config = VllmConfig(compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
        ))
        with vllm.config.set_current_vllm_config(vllm_config):
            # The no-op elimination pass is needed for the fusion pass to
            # work; it is handed to the backend ahead of the fusion pass.
            noop_pass = NoOpEliminationPass(vllm_config)
            fusion_pass = FusionPass.instance(vllm_config)

            backend = TestBackend(noop_pass, fusion_pass)
            model = TestModel(hidden_size, eps, static, cuda_force_torch)

            # First dimension dynamic
            x = torch.rand(num_tokens, hidden_size)
            torch._dynamo.mark_dynamic(x, 0)

            result = model(x)

            model2 = torch.compile(model, backend=backend)
            result2 = model2(x)

            # Higher tol for dynamic, even higher for bfloat16
            if static:
                ATOL, RTOL = (1e-3, 1e-3)
            elif dtype == torch.float16:
                ATOL, RTOL = (2e-3, 2e-3)
            else:
                ATOL, RTOL = (1e-2, 1e-2)

            torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

            # In pre-nodes, fp8 quant should be there and fused kernels
            # should not
            backend.check_before_ops(model.ops_in_model_before())

            # In post-nodes, fused kernels should be there and fp8 quant
            # should not
            backend.check_after_ops(model.ops_in_model_after())
    finally:
        torch.set_default_device(prev_device)
        torch.set_default_dtype(prev_dtype)