test_functionalization.py 4.37 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import pytest
import torch

import vllm.envs as envs
from vllm import LLM, SamplingParams
9
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
10
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
11
from vllm.compilation.fusion import FUSED_OPS, FusionPass
12
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
13
from vllm.compilation.noop_elimination import NoOpEliminationPass
14
from vllm.config import CompilationConfig, PassConfig, VllmConfig
15
16
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

from .backend import TestBackend

# Custom ops that appear in the compiled model graph whether or not the
# fusion passes run.
OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
]

# Unfused RMSNorm op; expected in the graph only when fusion is disabled.
RMS_OP = torch.ops._C.rms_norm.default

# RMSNorm + quantization ops, keyed by quantization scheme name.
# NOTE(review): not referenced by the test in this file (it uses FUSED_OPS);
# kept in case other modules import it.
RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ],
}

# SiLU-and-mul activation, and its variant fused with quantization.
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default

# Prompts generated from with each compiled variant of the model; the
# outputs of the two variants are compared pairwise.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.mark.parametrize(
    "model, quant_key",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
      kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model: str, quant_key: QuantKey,
                               do_fusion: bool):
    """Check that FixFunctionalizationPass de-functionalizes custom ops
    without changing generated text.

    The engine's model is compiled twice: once with the post-grad passes
    followed by FixFunctionalizationPass, and once with the post-grad passes
    only. The test then checks four things:

    * Both variants produce the same text at temperature 0.
    * The graph built without the fix pass contains, for every expected op,
      a node matched by ``find_auto_fn``.
    * The graph built with the fix pass contains no such node.
    * The graph built with the fix pass calls every expected op directly,
      as reported by ``is_func``.

    Args:
        model: Model id of an FP8-quantized checkpoint to load.
        quant_key: Quantization scheme of ``model``. It selects which fused
            ops are expected when fusion is enabled.
        do_fusion: Whether the fusion passes run before the fix pass.
    """
    # NOTE(review): this global default is never restored, so it stays in
    # effect for any later test in the same process.
    torch.set_default_device("cuda")

    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))

    noop_pass = NoOpEliminationPass(vllm_config)
    fusion_pass = FusionPass.instance(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass]
              if do_fusion else [noop_pass])
    func_pass = FixFunctionalizationPass(vllm_config)

    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    # Instantiate a full engine and manually compile the model twice
    # (with and without FixFunctionalizationPass).
    llm = LLM(model=model, enforce_eager=True)
    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
    orig_model = model_runner.model
    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
    # Can only do that by using the decorator but then we'd have to
    # instantiate 2 LLM instances.

    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_func)
    gen_func = llm.generate(prompts, sampling_params)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_no_func)
    gen_no_func = llm.generate(prompts, sampling_params)

    # zip() silently drops unmatched items, so compare the counts first.
    assert len(gen_func) == len(gen_no_func)
    for output_func, output_no_func in zip(gen_func, gen_no_func):
        assert output_func.outputs[0].text == output_no_func.outputs[0].text

    # OPS_IN_MODEL always appear. When fusion runs, RMS_OP is fused away
    # and replaced by the FUSED_OPS entries for quant_key (both values of
    # the boolean key flag, presumably with/without the residual add).
    if do_fusion:
        rms_ops = [
            FUSED_OPS[(quant_key, True)],
            FUSED_OPS[(quant_key, False)],
        ]
    else:
        rms_ops = [RMS_OP]

    # SiLU-and-mul is fused with quantization only for static per-tensor
    # FP8; otherwise the plain activation op remains.
    if do_fusion and quant_key == kFp8StaticTensorSym:
        silu_mul_ops = [SILU_MUL_QUANT_OP]
    else:
        silu_mul_ops = [SILU_MUL_OP]

    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops

    for op in ops:
        # Called for its check on the unfixed graph; presumably raises when
        # no match exists (the _maybe variant returns None instead).
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        leftover = find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
        assert leftover is None, (
            f"{op} still functionalized after FixFunctionalizationPass")

    # Make sure the ops were all de-functionalized. Collect the matches
    # first so a miss fails with a readable message, not a KeyError.
    found = {
        op
        for node in backend_func.graph_post_pass.nodes
        for op in ops if is_func(node, op)
    }
    missing = [op for op in ops if op not in found]
    assert not missing, f"ops not de-functionalized: {missing}"