# test_cudagraph_mode.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
11
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
12
from vllm import LLM
13
from vllm.config import CompilationConfig, CompilationMode
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the ``with`` body.

    On exit (normal or via exception) every key named in ``env_vars`` is put
    back: keys that previously had a value get it restored, and keys that did
    not exist before are removed again. This is used instead of pytest's
    ``monkeypatch`` because monkeypatch cannot be used with "module" scoped
    fixtures.
    """
    # Remember the pre-existing value (or None when absent) of every key we
    # are about to overwrite.
    snapshot = [(name, os.environ.get(name)) for name in env_vars]
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in snapshot:
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)


# Attention backend x cudagraph_mode combinations exercised by
# test_backend_and_cudagraph_mode_combo.
# Each entry is (backend_name, cudagraph_mode, supported).
combo_cases_1 = (
    [
        ("RocmAttn", "FULL", True),
        ("RocmAttn", "FULL_AND_PIECEWISE", True),
        ("TritonAttn", "FULL", True),
        ("TritonAttn", "FULL_AND_PIECEWISE", True),
    ]
    if current_platform.is_rocm()
    else [
        ("FA3", "FULL", True),
        ("FA3", "FULL_AND_PIECEWISE", True),
        ("FA2", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
        ("FA2", "FULL_AND_PIECEWISE", True),
        ("FlashInfer", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
        ("FlashInfer", "FULL_AND_PIECEWISE", True),
    ]
)
54
55


56
57
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supported):
    """Generate a few tokens with ``backend_name`` under ``cudagraph_mode``.

    The engine is built with ``CompilationMode.VLLM_COMPILE``. When
    ``supported`` is False, constructing the engine or generating is expected
    to raise; otherwise it must succeed.
    """
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")

    backend_config = backend_configs[backend_name]

    # Dynamically skip the case when the backend is pinned to a specific GPU
    # architecture that differs from the device we are running on.
    required_arch = backend_config.specific_gpu_arch
    if required_arch:
        device_arch = current_platform.get_device_capability()
        if required_arch != device_arch:
            pytest.skip(
                f"{backend_name} requires GPU arch {required_arch}, "
                f"but the current device is {device_arch}"
            )

    env_vars = backend_config.env_vars

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Any exception from LLM() / generate() counts as "unsupported".
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
            ),
        )
        llm.generate(["Hello, my name is"] * 10)

    # If the code above raised (expected for unsupported combos), `llm` was
    # never bound, so ignore that case. Otherwise rebinding `llm` to a weak
    # proxy and deleting it drops this frame's strong reference to the engine.
    with contextlib.suppress(UnboundLocalError):
        llm = weakref.proxy(llm)
        del llm

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


101
102
# cudagraph_mode x compilation mode combinations exercised by
# test_cudagraph_compilation_combo.
# Each entry is (backend_name, cudagraph_mode, compilation_mode, supported);
# entries with supported=False are expected to raise.
if current_platform.is_rocm():
    combo_cases_2 = [
        ("RocmAttn", "FULL", CompilationMode.NONE, True),
        ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
        ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
        ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
        ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
        ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
        ("RocmAttn", "NONE", CompilationMode.NONE, True),
        ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
    ]
else:
    combo_cases_2 = [
        ("FA2", "FULL", CompilationMode.NONE, True),
        ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
        ("FA2", "PIECEWISE", CompilationMode.NONE, True),
        ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
        ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
        ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
        ("FA2", "NONE", CompilationMode.NONE, True),
        ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
    ]
129
130


131
@pytest.mark.parametrize(
    "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
)
def test_cudagraph_compilation_combo(
    backend_name, cudagraph_mode, compilation_mode, supported
):
    """Generate a few tokens for a (cudagraph_mode, compilation mode) pair.

    When ``supported`` is False, constructing the engine or generating is
    expected to raise; otherwise it must succeed.
    """
    env_vars = backend_configs[backend_name].env_vars

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Any exception from LLM() / generate() counts as "unsupported".
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
                mode=compilation_mode, cudagraph_mode=cudagraph_mode
            ),
        )
        llm.generate(["Hello, my name is"] * 10)

    # If the code above raised (expected for unsupported combos), `llm` was
    # never bound, so ignore that case. Otherwise rebinding `llm` to a weak
    # proxy and deleting it drops this frame's strong reference to the engine.
    with contextlib.suppress(UnboundLocalError):
        llm = weakref.proxy(llm)
        del llm

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )