# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Context manager that applies ``env_vars`` to ``os.environ`` for the
    duration of the ``with`` body and then puts the previous state back.

    Variables that existed beforehand get their old value restored;
    variables that did not exist are removed again. Restoration also
    happens when the body raises.

    This is used instead of monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot the current value (or None when absent) of every key we are
    # about to touch, before anything is modified.
    saved = {}
    for key in env_vars:
        saved[key] = os.environ.get(key)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for key, previous in saved.items():
            if previous is not None:
                os.environ[key] = previous
            else:
                # The variable was not set before; make sure it is gone.
                os.environ.pop(key, None)


@dataclass
class BackendConfig:
    """Description of one attention-backend setup used by the tests below."""
    # Label for the backend; matches the key it is stored under in
    # ``backend_configs``.
    name: str
    # Environment variables the tests apply (via ``temporary_environ``)
    # before constructing the LLM.
    env_vars: dict
    # Compilation settings associated with this backend, e.g.
    # ``{"cudagraph_mode": ...}``.
    comp_config: dict
    # Device capability tuple the backend is pinned to; tests skip when the
    # current device reports a different capability. ``None`` means no
    # restriction.
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
# Maps a backend name to its BackendConfig. The tests below look entries up
# by name to obtain the environment variables (and, where set, the required
# GPU capability) for that backend.
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

# Cases pairing an attention backend with a cudagraph_mode.
# Each entry is (backend_name, cudagraph_mode, supported).
combo_cases_1 = [
    ("FA3", "FULL", True),
    ("FA3", "FULL_AND_PIECEWISE", True),
    # Should fallback to FULL_AND_PIECEWISE
    ("FA2", "FULL", True),
    ("FA2", "FULL_AND_PIECEWISE", True),
    # Should fallback to FULL_AND_PIECEWISE
    ("FlashInfer", "FULL", True),
    ("FlashInfer", "FULL_AND_PIECEWISE", True),
]


@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
    """Run one (backend, cudagraph_mode) pairing end to end.

    ``combo_case`` is a ``(backend_name, cudagraph_mode, supported)`` tuple.
    The case is skipped when FlashInfer is requested but cannot be imported,
    or when the backend is pinned to a GPU capability that differs from the
    current device's. Otherwise an ``LLM`` is built at compilation level 3
    with the requested ``cudagraph_mode`` and asked to generate; when
    ``supported`` is false that work is expected to raise.
    """
    backend_name, cudagraph_mode, supported = combo_case

    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")

    backend_config = backend_configs[backend_name]

    # Dynamically skip test if GPU capability is not met
    required_arch = backend_config.specific_gpu_arch
    if required_arch:
        if required_arch != current_platform.get_device_capability():
            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

    # Backend-specific variables are layered on top of the V1 switch.
    env_vars = dict(VLLM_USE_V1="1")
    env_vars.update(backend_config.env_vars)

    with temporary_environ(env_vars):
        with ExitStack() as stack:
            if not supported:
                stack.enter_context(pytest.raises(Exception))

            # Built inside the stack so that a failure here is also covered
            # by pytest.raises for unsupported cases.
            compilation_config = CompilationConfig(
                level=3, cudagraph_mode=cudagraph_mode)
            llm = LLM(
                model="Qwen/Qwen2-1.5B-Instruct",
                max_num_seqs=256,
                trust_remote_code=True,
                gpu_memory_utilization=0.45,
                max_model_len=1024,
                compilation_config=compilation_config,
            )
            llm.generate(["Hello, my name is"] * 10)

    # Drop the only strong reference to the engine. ``llm`` is unbound when
    # construction raised inside pytest.raises, hence the suppress.
    with contextlib.suppress(UnboundLocalError):
        llm = weakref.proxy(llm)
        del llm

    wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)


# Cases pairing a cudagraph_mode with a compilation level.
# Each entry is (backend_name, cudagraph_mode, compilation_level, supported).
combo_cases_2 = [
    # no compilation + full cudagraph
    ("FA2", "FULL", 0, True),
    # piecewise compilation + full cudagraph
    ("FA2", "FULL", 3, True),
    # no compilation + piecewise cudagraph
    ("FA2", "PIECEWISE", 0, False),
    # piecewise compilation + piecewise cudagraph
    ("FA2", "PIECEWISE", 3, True),
    # piecewise cudagraph not supported without piecewise compilation
    ("FA2", "FULL_AND_PIECEWISE", 0, False),
    ("FA2", "FULL_AND_PIECEWISE", 3, True),
    ("FA2", "FULL_DECODE_ONLY", 0, True),
    ("FA2", "FULL_DECODE_ONLY", 3, True),
    # no compilation + no cudagraph
    ("FA2", "NONE", 0, True),
    # piecewise compilation + no cudagraph
    ("FA2", "NONE", 3, True),
]


@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    """Run one (cudagraph_mode, compilation level) pairing end to end.

    ``combo_case`` is a ``(backend_name, cudagraph_mode, compilation_level,
    supported)`` tuple. An ``LLM`` is built with the given compilation level
    and ``cudagraph_mode`` under the backend's environment variables and
    asked to generate; when ``supported`` is false that work is expected to
    raise.
    """
    (backend_name, cudagraph_mode, compilation_level,
     supported) = combo_case

    # Backend-specific variables are layered on top of the V1 switch.
    env_vars = dict(VLLM_USE_V1="1")
    env_vars.update(backend_configs[backend_name].env_vars)

    with temporary_environ(env_vars):
        with ExitStack() as stack:
            if not supported:
                stack.enter_context(pytest.raises(Exception))

            # Built inside the stack so that a failure here is also covered
            # by pytest.raises for unsupported cases.
            compilation_config = CompilationConfig(
                level=compilation_level, cudagraph_mode=cudagraph_mode)
            llm = LLM(
                model="Qwen/Qwen2-1.5B-Instruct",
                max_num_seqs=256,
                trust_remote_code=True,
                gpu_memory_utilization=0.45,
                max_model_len=1024,
                compilation_config=compilation_config,
            )
            llm.generate(["Hello, my name is"] * 10)

    try:
        # Drop the only strong reference to the engine. ``llm`` is unbound
        # when construction raised inside pytest.raises, hence the suppress.
        with contextlib.suppress(UnboundLocalError):
            llm = weakref.proxy(llm)
            del llm
    finally:
        wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)