test_cudagraph_mode.py 4.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
11
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Context manager that applies `env_vars` to `os.environ` for the duration
    of the `with` body and puts the previous values back on exit.

    Variables that did not exist beforehand are removed again; variables that
    did exist get their earlier value restored. Restoration also happens when
    the body raises.

    This is used instead of pytest's monkeypatch because monkeypatch doesn't
    work with "module" scoped fixtures.
    """
    # Snapshot the current value (or None when unset) of every key we touch.
    saved = {}
    for name in env_vars:
        saved[name] = os.environ.get(name)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                # The variable was absent before entering; remove it again.
                os.environ.pop(name, None)


# Attention backend x cudagraph_mode combinations.
# Each entry is a (backend_name, cudagraph_mode, supported) tuple. Every
# backend is paired with both "FULL" and "FULL_AND_PIECEWISE", and every
# combination is marked as supported.
# NOTE: for FA2 and FlashInfer, "FULL" should fallback to FULL_AND_PIECEWISE.
combo_cases_1 = [
    (backend, mode, True)
    for backend in ("FA3", "FA2", "FlashInfer")
    for mode in ("FULL", "FULL_AND_PIECEWISE")
]


48
49
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supported):
    """Build an LLM with the given attention backend and cudagraph mode.

    Args:
        backend_name: Key into `backend_configs` selecting the attention
            backend under test.
        cudagraph_mode: Value passed to `CompilationConfig(cudagraph_mode=...)`.
        supported: When False, constructing the engine or generating is
            expected to raise; when True, both must succeed.

    The test is skipped when FlashInfer is requested but cannot be imported,
    or when the backend config pins a specific GPU architecture that differs
    from the current device capability.
    """
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")
    backend_config = backend_configs[backend_name]
    # Dynamically skip test if GPU capability is not met
    if (
        backend_config.specific_gpu_arch
        and backend_config.specific_gpu_arch != current_platform.get_device_capability()
    ):
        pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

    # Reuse the config fetched above rather than indexing the dict again.
    env_vars = backend_config.env_vars

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Unsupported combos must fail somewhere inside this block.
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
                level=3, cudagraph_mode=cudagraph_mode
            ),
        )
        llm.generate(["Hello, my name is"] * 10)

    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        # Rebinding to a weak proxy and deleting the name drops this
        # function's strong reference to the engine.
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass
    finally:
        # Always wait before the next parametrized case, matching the cleanup
        # in test_cudagraph_compilation_combo.
        wait_for_gpu_memory_to_clear(
            devices=[0],
            threshold_ratio=0.1,
        )


# Test cudagraph_mode with different compilation levels.
# Each entry is (backend_name, cudagraph_mode, compilation_level, supported);
# supported=False marks combinations that are expected to raise.
combo_cases_2 = [
    ("FA2", "FULL", 0, True),  # no compilation + full cudagraph
    ("FA2", "FULL", 3, True),  # piecewise compilation + full cudagraph
    ("FA2", "PIECEWISE", 0, False),  # no compilation + piecewise cudagraph
    ("FA2", "PIECEWISE", 3, True),  # piecewise compilation + piecewise cudagraph
    # piecewise cudagraph not supported without piecewise compilation
    ("FA2", "FULL_AND_PIECEWISE", 0, False),
    ("FA2", "FULL_AND_PIECEWISE", 3, True),
    ("FA2", "FULL_DECODE_ONLY", 0, True),
    ("FA2", "FULL_DECODE_ONLY", 3, True),
    ("FA2", "NONE", 0, True),  # no compilation + no cudagraph
    ("FA2", "NONE", 3, True),  # piecewise compilation + no cudagraph
]


114
115
116
# Parametrize on the single `combo_case` tuple so the argname matches the
# function signature; each tuple is unpacked inside the test.
@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    """Build an LLM for one cudagraph-mode / compilation-level combination.

    Args:
        combo_case: A `(backend_name, cudagraph_mode, compilation_level,
            supported)` tuple taken from `combo_cases_2`. When `supported`
            is False, constructing the engine or generating is expected to
            raise; otherwise both must succeed.
    """
    backend_name, cudagraph_mode, compilation_level, supported = combo_case

    env_vars = backend_configs[backend_name].env_vars

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Unsupported combos must fail somewhere inside this block.
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
                level=compilation_level, cudagraph_mode=cudagraph_mode
            ),
        )
        llm.generate(["Hello, my name is"] * 10)

    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        # Rebinding to a weak proxy and deleting the name drops this
        # function's strong reference to the engine.
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass
    finally:
        # Runs whether or not `llm` was ever bound.
        wait_for_gpu_memory_to_clear(
            devices=[0],
            threshold_ratio=0.1,
        )