test_fusion.py 5.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pytest
import torch

7
import vllm.plugins
8
9
10
11
12
13
from vllm.compilation.fusion import (
    FUSED_OPS,
    QUANT_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
14
from vllm.compilation.noop_elimination import NoOpEliminationPass
15
from vllm.compilation.post_cleanup import PostCleanupPass
16
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
17
from vllm.model_executor.layers.layernorm import RMSNorm
18
from vllm.model_executor.layers.quantization.utils.quant_utils import (
19
20
21
22
    GroupShape,
    QuantKey,
    ScaleDesc,
)
23
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
24
25
26
27
    Fp8LinearOp,
    cutlass_fp8_supported,
    maybe_create_device_identity,
)
28
from vllm.platforms import current_platform
29

30
from ..utils import override_cutlass_fp8_supported
31
32
from .backend import TestBackend

33
34
# FP8 dtype reported by the current platform; used for the quantized weights
# and for the QuantKey built in TestModel.
FP8_DTYPE = current_platform.fp8_dtype()

35
36

class TestModel(torch.nn.Module):
    """Toy model chaining RMSNorm layers with FP8-quantized linear layers.

    The forward pass is ``norm -> fp8 linear -> residual norm -> fp8 linear
    -> residual norm``. One RMSNorm without a residual input and one RMSNorm
    with a residual input therefore each feed an FP8 quantization, which is
    what the RMSNorm+quant fusion pass is expected to rewrite.

    Args:
        hidden_size: Size of every RMSNorm and of the square FP8 weights.
        eps: Epsilon passed to each ``RMSNorm``.
        static: If True, activations use a static per-tensor scale (a random
            ``input_scale`` is passed); otherwise a per-token group shape is
            used and no ``input_scale`` is supplied.
        cuda_force_torch: If True, ``Fp8LinearOp`` is built while CUTLASS FP8
            is overridden as unsupported, exercising the torch code path.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        static: bool,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        # A plain list (not nn.ModuleList), so the norms are not registered
        # as submodules of this module.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]

        # The torch.rand calls run in a fixed order (weight scales, then
        # activation scales, then weights) so seeded runs are reproducible.
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]

        act_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
        self.key = QuantKey(
            dtype=FP8_DTYPE,
            scale=ScaleDesc(torch.float32, static, act_shape),
            symmetric=True,
        )

        self.scale = (
            [torch.rand(1, dtype=torch.float32) for _ in range(2)]
            if static
            else [None] * 2
        )
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=static,
                act_quant_group_shape=act_shape,
            )

    def forward(self, x):
        """Apply the norm / FP8-linear stack to ``x`` and return the result.

        The residual starts as ``sqrt(x)`` and is threaded through the second
        and third norms. The residual must actually be consumed downstream
        for the fused replacement to work.
        """
        residual = torch.sqrt(x)
        hidden = self.norm[0](x)
        # Two identical stages: quantized FP8 linear, then RMSNorm that also
        # takes and returns the residual.
        for stage in range(2):
            hidden = self.fp8_linear.apply(
                hidden,
                self.w[stage],
                self.wscale[stage],
                input_scale=self.scale[stage],
            )
            hidden, residual = self.norm[stage + 1](hidden, residual)
        return hidden

    def ops_in_model_before(self):
        """Return the ops expected before fusion: the standalone FP8 quant op."""
        return [QUANT_OPS[self.key]]

    def ops_in_model_after(self):
        """Return the fused RMSNorm+quant ops expected after fusion.

        One entry per ``FusedRMSQuantKey`` flag value, ``False`` then
        ``True`` (presumably the without-/with-residual variants; confirm
        against ``FusedRMSQuantKey``).
        """
        return [
            FUSED_OPS[FusedRMSQuantKey(self.key, fused_add)]
            for fused_add in (False, True)
        ]

93
94

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch=True exercises the torch code path even on platforms where
# cutlass_fp8_supported() is True; without CUTLASS only that path is tested.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
):
    """Compile TestModel with the RMSNorm+quant fusion pass and verify it.

    Checks performed:
    - the compiled output matches eager output within tolerance;
    - the fusion pass reports exactly two matches;
    - the standalone quant op is present before fusion;
    - the fused kernels are present after fusion.
    """
    # NOTE(review): these set process-global torch defaults and nothing in
    # this test restores them afterwards.
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    )
    vllm_config = VllmConfig(compilation_config=compilation_config)

    with vllm.config.set_current_vllm_config(vllm_config):
        # The no-op elimination pass is needed for the fusion pass to work.
        noop_pass, fusion_pass, cleanup_pass = (
            NoOpEliminationPass(vllm_config),
            RMSNormQuantFusionPass(vllm_config),
            PostCleanupPass(vllm_config),
        )
        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        model = TestModel(hidden_size, eps, static, cuda_force_torch)

        # First (token) dimension is marked dynamic.
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        eager_out = model(x)

        compiled_model = torch.compile(model, backend=backend)
        compiled_out = compiled_model(x)

        # Dynamic (per-token) quantization needs a looser tolerance, and
        # bfloat16 looser still; atol and rtol are equal in every case.
        if static:
            tol = 1e-3
        elif dtype == torch.float16:
            tol = 2e-3
        else:
            tol = 1e-2
        torch.testing.assert_close(eager_out, compiled_out, atol=tol, rtol=tol)

        assert fusion_pass.matched_count == 2

        # Before fusion: fp8 quant is present and fused kernels are not.
        backend.check_before_ops(model.ops_in_model_before())

        # After fusion: fused kernels are present and fp8 quant is not.
        backend.check_after_ops(model.ops_in_model_after())