# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from vllm import LLM, SamplingParams
9
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
10
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
11
12
13
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                     kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
14
from vllm.compilation.noop_elimination import NoOpEliminationPass
15
from vllm.config import CompilationConfig, PassConfig, VllmConfig
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

from .backend import TestBackend

# Custom ops expected in the compiled graph whether or not fusion runs.
OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
]

# Unfused RMS norm; the test expects it only when fusion is disabled.
RMS_OP = torch.ops._C.rms_norm.default

# RMS norm fused with static FP8 quantization.
# NOTE(review): test_fix_functionalization takes its fused RMS ops from
# FUSED_OPS instead of this table — confirm whether anything still imports it.
RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ],
}

# Unfused SiLU-and-mul activation.
SILU_MUL_OP = torch.ops._C.silu_and_mul.default

# SiLU-and-mul fused with quantization; the test expects it only when fusion
# is enabled with a static FP8 quant key.
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default

# Prompts for the two greedy generations compared by the test.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.mark.parametrize(
    "model, quant_key",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
      kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model: str, quant_key: QuantKey,
                               do_fusion: bool):
    """Check that FixFunctionalizationPass de-functionalizes custom ops.

    The model is compiled twice with the same pass pipeline. The only
    difference is whether FixFunctionalizationPass is appended. Greedy
    generations from both compilations must match. Then the post-pass graphs
    are inspected:

    * in the graph compiled without the fix pass, each expected op must be
      found by ``find_auto_fn``;
    * in the graph compiled with the fix pass, ``find_auto_fn_maybe`` must
      return None for each op, and every op must appear as a node matched by
      ``is_func``.

    Args:
        model: HF model id of the FP8-quantized checkpoint to load.
        quant_key: Quantization scheme of ``model``. It selects the expected
            fused ops.
        do_fusion: Whether the fusion passes run before the fix pass.
    """
    torch.set_default_device("cuda")

    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))

    noop_pass = NoOpEliminationPass(vllm_config)
    fusion_pass = FusionPass.instance(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

    if do_fusion:
        passes = [noop_pass, fusion_pass, act_quant_fusion_pass]
    else:
        passes = [noop_pass]
    func_pass = FixFunctionalizationPass(vllm_config)

    # Both backends share the same passes, so any graph difference comes
    # from FixFunctionalizationPass alone.
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    # instantiate a full engine and manually compile the model 2x
    # (with and without FixFunctionalizationPass)
    llm = LLM(model=model, enforce_eager=True)
    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
    orig_model = model_runner.model
    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
    # Can only do that by using the decorator but then we'd have to instantiate
    # 2 LLM instances.

    # Greedy sampling, so the two runs are directly comparable.
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_func)
    gen_func = llm.generate(prompts, sampling_params)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_no_func)
    gen_no_func = llm.generate(prompts, sampling_params)

    # zip() would silently drop unmatched outputs; check the lengths first.
    assert len(gen_func) == len(gen_no_func)
    for output_func, output_no_func in zip(gen_func, gen_no_func):
        assert output_func.outputs[0].text == output_no_func.outputs[0].text

    # OPS_IN_MODEL always appear. When fusion runs, RMS_OP is replaced by the
    # two FUSED_OPS entries for quant_key.
    if do_fusion:
        rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]]
    else:
        rms_ops = [RMS_OP]

    # silu_and_mul is replaced by its fused quant op only with fusion enabled
    # and a static FP8 tensor quant key.
    if do_fusion and quant_key == kFp8StaticTensorSym:
        silu_mul_ops = [SILU_MUL_QUANT_OP]
    else:
        silu_mul_ops = [SILU_MUL_OP]

    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops

    nodes_no_func = backend_no_func.graph_post_pass.nodes
    nodes_func = backend_func.graph_post_pass.nodes

    for op in ops:
        # NOTE(review): find_auto_fn is presumably expected to fail when the
        # op is absent — confirm in vllm.compilation.fx_utils.
        find_auto_fn(nodes_no_func, op)
        assert find_auto_fn_maybe(nodes_func, op) is None

    # make sure the ops were all de-functionalized; report which ones were
    # not, rather than failing with a KeyError on the first missing op.
    found = {op for node in nodes_func for op in ops if is_func(node, op)}
    missing = [op for op in ops if op not in found]
    assert not missing, f"ops not de-functionalized: {missing}"