# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Context manager that applies ``env_vars`` to ``os.environ`` for the
    duration of the ``with`` block and puts the previous values back on exit,
    including when the block raises.

    Variables that did not exist beforehand are removed again rather than
    being left behind.

    This is used instead of monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot the current value of every key we are about to touch;
    # None marks "was not set".
    saved = {}
    for name in env_vars:
        saved[name] = os.environ.get(name)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
                continue
            # The variable was absent before; drop it (tolerating the case
            # where the block already removed it).
            os.environ.pop(name, None)


# Attention backend x cudagraph_mode combinations.
# Each entry is (backend_name, cudagraph_mode, supported); every combination
# listed here is expected to work.
# Note: for FA2 and FlashInfer, "FULL" should fall back to
# FULL_AND_PIECEWISE.
combo_cases_1 = [
    (backend, mode, True)
    for backend in ("FA3", "FA2", "FlashInfer")
    for mode in ("FULL", "FULL_AND_PIECEWISE")
]


@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
                         combo_cases_1)
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
                                          supported):
    """Run a small model with one attention backend / cudagraph mode pair.

    Args:
        backend_name: Key into ``backend_configs`` selecting the backend.
        cudagraph_mode: Value passed to ``CompilationConfig.cudagraph_mode``.
        supported: When false, building the ``LLM`` or generating is
            expected to raise; when true it must succeed.

    The test is skipped when FlashInfer is requested but not importable, or
    when the backend config names a specific GPU architecture that differs
    from the current device capability.
    """
    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")

    backend_config = backend_configs[backend_name]
    # Dynamically skip test if GPU capability is not met
    required_arch = backend_config.specific_gpu_arch
    if (required_arch
            and required_arch != current_platform.get_device_capability()):
        pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

    env_vars = {"VLLM_USE_V1": "1", **backend_config.env_vars}

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Entered last, so it exits first and absorbs the expected error
            # before the environment is restored.
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  max_num_seqs=256,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.45,
                  max_model_len=1024,
                  compilation_config=CompilationConfig(
                      level=3, cudagraph_mode=cudagraph_mode))
        llm.generate(["Hello, my name is"] * 10)

    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        # Rebinding to a weak proxy releases this function's strong
        # reference to the engine before the name itself is deleted.
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass

    # NOTE(review): presumably keeps leftover GPU memory from this case from
    # affecting the next parametrized case - confirm against the helper.
    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


# cudagraph_mode crossed with compilation level
# (0 = no compilation, 3 = piecewise compilation).
# Each entry is (backend_name, cudagraph_mode, compilation_level, supported).
combo_cases_2 = [
    # full cudagraph, without and with piecewise compilation
    ("FA2", "FULL", 0, True),
    ("FA2", "FULL", 3, True),
    # piecewise cudagraph, without and with piecewise compilation
    ("FA2", "PIECEWISE", 0, False),
    ("FA2", "PIECEWISE", 3, True),
    # piecewise cudagraph not supported without piecewise compilation
    ("FA2", "FULL_AND_PIECEWISE", 0, False),
    ("FA2", "FULL_AND_PIECEWISE", 3, True),
    ("FA2", "FULL_DECODE_ONLY", 0, True),
    ("FA2", "FULL_DECODE_ONLY", 3, True),
    # no cudagraph, without and with piecewise compilation
    ("FA2", "NONE", 0, True),
    ("FA2", "NONE", 3, True),
]


# The whole case tuple is passed as a single argument: the argnames given to
# parametrize must match the test function's parameters, otherwise pytest
# fails at collection time.
@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    """Run a small model with one cudagraph mode / compilation level pair.

    Args:
        combo_case: A ``(backend_name, cudagraph_mode, compilation_level,
            supported)`` tuple from ``combo_cases_2``. When ``supported`` is
            false, building the ``LLM`` or generating is expected to raise.
    """
    backend_name, cudagraph_mode, compilation_level, supported = combo_case

    env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            # Entered last, so it exits first and absorbs the expected error
            # before the environment is restored.
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  max_num_seqs=256,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.45,
                  max_model_len=1024,
                  compilation_config=CompilationConfig(
                      level=compilation_level, cudagraph_mode=cudagraph_mode))
        llm.generate(["Hello, my name is"] * 10)

    # when above code raises, `llm` may be undefined, so we need to catch that
    try:
        # Rebinding to a weak proxy releases this function's strong
        # reference to the engine before the name itself is deleted.
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass
    finally:
        # NOTE(review): presumably keeps leftover GPU memory from this case
        # from affecting the next parametrized case - confirm against the
        # helper.
        wait_for_gpu_memory_to_clear(
            devices=[0],
            threshold_ratio=0.1,
        )