test_cudagraph_mode.py 5.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
11
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
12
from vllm import LLM
13
from vllm.config import CompilationConfig, CompilationMode
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.
    We have to do this vs monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
if current_platform.is_rocm():
    combo_cases_1 = [
        ("RocmAttn", "FULL", True),
        ("RocmAttn", "FULL_AND_PIECEWISE", True),
        ("TritonAttn", "FULL", True),
        ("TritonAttn", "FULL_AND_PIECEWISE", True),
    ]
else:
    combo_cases_1 = [
        ("FA3", "FULL", True),
        ("FA3", "FULL_AND_PIECEWISE", True),
        ("FA2", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
        ("FA2", "FULL_AND_PIECEWISE", True),
        ("FlashInfer", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
        ("FlashInfer", "FULL_AND_PIECEWISE", True),
    ]
54
55


56
57
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supported):
58
59
60
61
62
63
64
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")
    backend_config = backend_configs[backend_name]
    # Dynamically skip test if GPU capability is not met
65
66
67
68
    if (
        backend_config.specific_gpu_arch
        and backend_config.specific_gpu_arch != current_platform.get_device_capability()
    ):
69
70
        pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

71
    env_vars = backend_configs[backend_name].env_vars
72
73
74
75
76

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            stack.enter_context(pytest.raises(Exception))

77
78
79
80
81
82
83
        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
84
                mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
85
86
            ),
        )
87
        llm.generate(["Hello, my name is"] * 10)
88
    # when above code raises, `llm` may be undefined, so we need to catch that
89
90
91
92
93
94
95
96
97
98
99
100
    try:
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


101
102
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
103
104
105
106
107
108
109
110
111
112
113
114
115
116
attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2"

combo_cases_2 = [
    (attn_backend, "FULL", CompilationMode.NONE, True),
    (attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True),
    (attn_backend, "PIECEWISE", CompilationMode.NONE, True),
    (attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
    (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
    (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
    (attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True),
    (attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
    (attn_backend, "NONE", CompilationMode.NONE, True),
    (attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True),
]
117
118


119
@pytest.mark.parametrize(
120
    "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
121
)
122
123
124
def test_cudagraph_compilation_combo(
    backend_name, cudagraph_mode, compilation_mode, supported
):
125
    env_vars = backend_configs[backend_name].env_vars
126
127
128
129
130

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            stack.enter_context(pytest.raises(Exception))

131
132
133
134
135
136
137
        llm = LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            max_num_seqs=256,
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
            compilation_config=CompilationConfig(
138
                mode=compilation_mode, cudagraph_mode=cudagraph_mode
139
140
            ),
        )
141
        llm.generate(["Hello, my name is"] * 10)
142
    # when above code raises, `llm` may be undefined, so we need to catch that
143
144
145
146
147
148
149
150
151
152
    try:
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass
    finally:
        wait_for_gpu_memory_to_clear(
            devices=[0],
            threshold_ratio=0.1,
        )