# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                     FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform

from .backend import TestBackend

# FP8 dtype reported by the current platform; used below for the test weights
# and for the quantization key looked up in the fusion op tables.
FP8_DTYPE = current_platform.fp8_dtype()


class TestModel(torch.nn.Module):
    """Small RMSNorm / FP8-linear stack used to exercise the fusion pass.

    ``forward`` interleaves three RMSNorm layers with two FP8 linear ops:
    the first norm is called without a residual, the other two are called
    with one, so both the plain and the residual-add norm variants are
    followed by an FP8 linear op.

    Args:
        hidden_size: Size passed to each ``RMSNorm`` and used for both
            dimensions of the square weight matrices.
        eps: Epsilon passed to each ``RMSNorm``.
        static: When true, a random float32 scale is created per linear op;
            otherwise the per-op scale is ``None``.
        cutlass_fp8_enabled: Forwarded to ``Fp8LinearOp`` as
            ``cutlass_fp8_supported``.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, hidden_size: int, eps: float, static: bool,
                 cutlass_fp8_enabled: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cutlass_fp8_enabled = cutlass_fp8_enabled

        # Kept in plain Python lists (not ModuleList), as in the original.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]

        # NOTE: the order of the torch.rand calls below determines the values
        # drawn after the caller seeds the RNG; do not reorder.
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        else:
            self.scale = [None for _ in range(2)]

        # Random square weights, cast to FP8 and transposed.
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            use_per_token_if_dynamic=True)

    def forward(self, x):
        """Run norm -> linear -> norm(+resid) -> linear -> norm(+resid).

        Args:
            x: Input tensor; its element-wise square root seeds the residual.

        Returns:
            The first output of the final RMSNorm call.
        """
        resid = torch.sqrt(x)
        y = self.norm[0](x)

        # Fourth positional argument is self.scale[i]; presumably the input
        # scale of Fp8LinearOp.apply -- confirm against its signature.
        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])

        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
                                   self.scale[1])

        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                              cutlass_fp8_enabled):
    """Check that RMSNorm + FP8 quant is fused and stays numerically close.

    Runs ``TestModel`` eagerly and through ``torch.compile`` with a backend
    that applies the no-op elimination and fusion passes, compares the two
    outputs, and then inspects the graph before and after the passes to
    confirm the standalone FP8 quant op was replaced by the fused ops.
    """
    # NOTE: these mutate process-wide torch defaults and are not restored
    # within this test.
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
    with vllm.config.set_current_vllm_config(vllm_config):
        # The no-op elimination pass is needed for the fusion pass to work
        config = CompilationConfig.PassConfig(enable_fusion=True,
                                              enable_noop=True)
        noop_pass = NoOpEliminationPass(config)
        fusion_pass = FusionPass.instance(config)

        backend = TestBackend(noop_pass, fusion_pass)
        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        # Eager reference result vs. result through the custom backend.
        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Higher tol for dynamic, even higher for bfloat16
        if static:
            ATOL, RTOL = (1e-3, 1e-3)
        elif dtype == torch.float16:
            ATOL, RTOL = (2e-3, 2e-3)
        else:
            ATOL, RTOL = (1e-2, 1e-2)

        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

        # Check substitution worked
        pre_nodes = backend.graph_pre_pass.nodes
        post_nodes = backend.graph_post_pass.nodes

        # static is per-tensor, dynamic is per-token
        key = QuantKey(dtype=FP8_DTYPE,
                       static=static,
                       per_tensor=static,
                       symmetric=True)
        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
        fp8_quant = QUANT_OPS[key]

        # In pre-nodes, fp8 quant should be there and fused kernels should not.
        # find_auto_fn is called without an assert; presumably it raises when
        # no matching node exists -- confirm in vllm.compilation.fx_utils.
        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
        find_auto_fn(pre_nodes, fp8_quant)

        # In post-nodes, fused kernels should be there and fp8 quant should not
        find_auto_fn(post_nodes, rms_quant)
        find_auto_fn(post_nodes, add_rms_quant)
        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None