# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                     FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

from .backend import TestBackend


class TestModel(torch.nn.Module):
    """Small model containing the RMSNorm -> FP8 linear patterns to fuse.

    ``forward`` runs ``RMSNorm -> fp8 linear -> RMSNorm(+residual) ->
    fp8 linear -> RMSNorm(+residual)``. The input of each FP8 linear is the
    output of an RMSNorm, so the fusion pass sees one plain ``rms_norm +
    quant`` site and one residual ``add + rms_norm + quant`` site.
    """

    def __init__(self, hidden_size: int, eps: float, static: bool,
                 cutlass_fp8_enabled: bool, *args, **kwargs):
        """Create the norms, FP8 weights and the weight/activation scales.

        Args:
            hidden_size: Size of the last input dimension; both weights are
                ``hidden_size x hidden_size``.
            eps: Epsilon used by every ``RMSNorm``.
            static: If True, use fixed random activation scales; otherwise
                the activation scales are ``None`` (dynamic quantization).
            cutlass_fp8_enabled: Forwarded to ``apply_fp8_linear`` as
                ``cutlass_fp8_supported``.
            *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.cutlass_fp8_enabled = cutlass_fp8_enabled

        # NOTE(review): these are plain Python lists, so the norms and
        # tensors are not registered as submodules/parameters/buffers
        # (e.g. ``model.to()`` will not move them).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        else:
            self.scale = [None, None]

        # NOTE(review): weights are cast to FP8 and then transposed with
        # ``.t()``; presumably the FP8 GEMM expects this layout — confirm.
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

    def forward(self, x):
        """Run the norm / FP8-linear chain and return the final norm output."""
        # Residual stream consumed by the two residual RMSNorm calls below.
        resid = torch.sqrt(x)
        y = self.norm[0](x)

        x2 = apply_fp8_linear(y,
                              self.w[0],
                              self.wscale[0],
                              self.scale[0],
                              use_per_token_if_dynamic=True,
                              cutlass_fp8_supported=self.cutlass_fp8_enabled)

        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = apply_fp8_linear(y2,
                              self.w[1],
                              self.wscale[1],
                              self.scale[1],
                              use_per_token_if_dynamic=True,
                              cutlass_fp8_supported=self.cutlass_fp8_enabled)

        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                              cutlass_fp8_enabled):
    """Check that FusionPass fuses RMSNorm + FP8 quant without changing output.

    Runs ``TestModel`` eagerly and via ``torch.compile`` with a
    ``TestBackend`` applying the no-op elimination and fusion passes,
    compares the two outputs, then inspects the FX graphs recorded before
    and after the passes: the standalone FP8 quant op must be replaced by
    the fused ``rms_norm + quant`` and ``add + rms_norm + quant`` ops.
    """
    # The default dtype and device are process-global. Scope them to this
    # test so later tests in the same pytest session are not affected.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device("cuda"):
            torch.manual_seed(1)
            # needed for certain non-cutlass fp8 paths
            maybe_create_device_identity()

            vllm_config = VllmConfig(compilation_config=CompilationConfig(
                level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
            with vllm.config.set_current_vllm_config(vllm_config):
                # Reshape (no-op elimination) pass is needed for the fusion
                # pass to work
                config = CompilationConfig.PassConfig(enable_fusion=True,
                                                      enable_noop=True)
                noop_pass = NoOpEliminationPass(config)
                fusion_pass = FusionPass.instance(config)

                backend = TestBackend(noop_pass, fusion_pass)
                model = TestModel(hidden_size, eps, static,
                                  cutlass_fp8_enabled)

                # First dimension dynamic
                x = torch.rand(num_tokens, hidden_size)
                torch._dynamo.mark_dynamic(x, 0)

                result = model(x)

                model2 = torch.compile(model, backend=backend)
                result2 = model2(x)

                # Higher tol for dynamic, even higher for bfloat16
                if static:
                    atol, rtol = (1e-3, 1e-3)
                elif dtype == torch.float16:
                    atol, rtol = (2e-3, 2e-3)
                else:
                    atol, rtol = (1e-2, 1e-2)

                torch.testing.assert_close(result, result2, atol=atol,
                                           rtol=rtol)

                # Check substitution worked
                pre_nodes = backend.graph_pre_pass.nodes
                post_nodes = backend.graph_post_pass.nodes

                # static is per-tensor, dynamic is per-token
                key = QuantKey(dtype=FP8_DTYPE,
                               static=static,
                               per_tensor=static,
                               symmetric=True)
                rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
                add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
                fp8_quant = QUANT_OPS[key]

                # In pre-nodes, fp8 quant should be there and fused kernels
                # should not
                assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
                assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
                find_auto_fn(pre_nodes, fp8_quant)

                # In post-nodes, fused kernels should be there and fp8 quant
                # should not
                find_auto_fn(post_nodes, rms_quant)
                find_auto_fn(post_nodes, add_rms_quant)
                assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
    finally:
        torch.set_default_dtype(prev_dtype)