# test_fusion.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
import pytest
import torch

import vllm.envs as envs
8
import vllm.plugins
9
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
10
                                     FusionPass, GroupShape, QuantKey)
11
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
13
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)
14
15
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
16
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
17
from vllm.platforms import current_platform
18
19
20

from .backend import TestBackend

21
22
# fp8 dtype reported by the current platform; used both for the quantized
# test weights and for the QuantKey that selects the quant / fused ops.
FP8_DTYPE: torch.dtype = current_platform.fp8_dtype()

class TestModel(torch.nn.Module):
    """Small model that gives the fusion pass RMSNorm + fp8-quant patterns.

    ``forward`` applies one RMSNorm without a residual, then twice runs an
    fp8 linear followed by an RMSNorm that also takes and returns a residual.
    """

    def __init__(self, hidden_size: int, eps: float, static: bool,
                 cutlass_fp8_enabled: bool, *args, **kwargs):
        """Create the norms, fp8 weights, scales and the fp8 linear op.

        Args:
            hidden_size: width of the input and of each square weight.
            eps: epsilon passed to every ``RMSNorm``.
            static: True for per-tensor static input scales; False for
                per-token dynamic quantization (no input scale supplied).
            cutlass_fp8_enabled: forwarded to ``Fp8LinearOp`` as
                ``cutlass_fp8_supported``.
        """
        super().__init__(*args, **kwargs)
        self.cutlass_fp8_enabled = cutlass_fp8_enabled

        def rand_scale():
            return torch.rand(1, dtype=torch.float32)

        # NOTE(review): plain Python lists, so these norms/tensors are not
        # registered as submodules or parameters of this nn.Module.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [rand_scale(), rand_scale()]

        if static:
            group_shape = GroupShape.PER_TENSOR
        else:
            group_shape = GroupShape.PER_TOKEN
        self.key = QuantKey(dtype=FP8_DTYPE,
                            static=static,
                            group_shape=group_shape,
                            symmetric=True)

        # Static quantization supplies a fixed input scale per linear layer;
        # dynamic quantization passes input_scale=None to Fp8LinearOp.
        self.scale = [rand_scale(), rand_scale()] if static else [None, None]

        # Random square weights cast to fp8 and stored transposed.
        weights = []
        for _ in range(2):
            dense = torch.rand(hidden_size, hidden_size)
            weights.append(dense.to(dtype=FP8_DTYPE).t())
        self.w = weights

        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            use_per_token_if_dynamic=True)

    def forward(self, x):
        """Apply norm, then (fp8 linear -> residual norm) twice; return output."""
        # The residual must actually be consumed by the residual norms for
        # the pattern replacement to work.
        residual = torch.sqrt(x)
        hidden = self.norm[0](x)
        for idx in range(2):
            projected = self.fp8_linear.apply(hidden,
                                              self.w[idx],
                                              self.wscale[idx],
                                              input_scale=self.scale[idx])
            hidden, residual = self.norm[idx + 1](projected, residual)
        return hidden

    def ops_in_model_before(self):
        """Ops expected in the graph before fusion: the standalone quant op."""
        quant_op = QUANT_OPS[self.key]
        return [quant_op]

    def ops_in_model_after(self):
        """Ops expected in the graph after fusion: both fused variants."""
        # NOTE(review): the bool in FusedRMSQuantKey appears to select the
        # residual (fused-add) variant -- confirm against its definition.
        return [
            FUSED_OPS[FusedRMSQuantKey(self.key, with_residual)]
            for with_residual in (False, True)
        ]

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                              cutlass_fp8_enabled):
    """Check RMSNorm + fp8 quant fusion on a compiled ``TestModel``.

    The model runs eagerly and through ``torch.compile`` with a
    ``TestBackend`` that applies ``NoOpEliminationPass`` and then
    ``FusionPass``. Compiled output must match eager output within a
    dtype/quantization-dependent tolerance. The standalone quant op is
    expected before the passes and the fused RMSNorm+quant ops after them.

    The default device and dtype are scoped to this test so that they do not
    leak into other tests running in the same pytest process.
    """
    saved_dtype = torch.get_default_dtype()
    try:
        # Unlike torch.set_default_device, the context manager reverts the
        # default device when the block exits.
        with torch.device("cuda"):
            torch.set_default_dtype(dtype)
            torch.manual_seed(1)
            # needed for certain non-cutlass fp8 paths
            maybe_create_device_identity()

            vllm_config = VllmConfig(compilation_config=CompilationConfig(
                level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
            vllm_config.compilation_config.pass_config = PassConfig(
                enable_fusion=True, enable_noop=True)

            with vllm.config.set_current_vllm_config(vllm_config):
                # The no-op elimination pass must run before the fusion pass
                # for fusion to work.
                noop_pass = NoOpEliminationPass(vllm_config)
                fusion_pass = FusionPass.instance(vllm_config)

                backend = TestBackend(noop_pass, fusion_pass)
                model = TestModel(hidden_size, eps, static,
                                  cutlass_fp8_enabled)

                # First dimension dynamic
                x = torch.rand(num_tokens, hidden_size)
                torch._dynamo.mark_dynamic(x, 0)

                result = model(x)

                model2 = torch.compile(model, backend=backend)
                result2 = model2(x)

                # Higher tol for dynamic, even higher for bfloat16
                if static:
                    ATOL, RTOL = (1e-3, 1e-3)
                elif dtype == torch.float16:
                    ATOL, RTOL = (2e-3, 2e-3)
                else:
                    ATOL, RTOL = (1e-2, 1e-2)

                torch.testing.assert_close(result,
                                           result2,
                                           atol=ATOL,
                                           rtol=RTOL)

                # In pre-nodes, fp8 quant should be there and fused kernels
                # should not
                backend.check_before_ops(model.ops_in_model_before())

                # In post-nodes, fused kernels should be there and fp8 quant
                # should not
                backend.check_after_ops(model.ops_in_model_after())
    finally:
        # Restore the caller's default dtype even if the test fails.
        torch.set_default_dtype(saved_dtype)