test_full_cudagraph.py 5.02 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import contextlib
import os
5
6
import weakref
from contextlib import ExitStack
7
8
9

import pytest

10
from tests.utils import wait_for_gpu_memory_to_clear
11
12
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
13
from vllm.platforms import current_platform
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the ``with`` body.

    On exit (normal or via an exception) every key named in ``env_vars`` is
    put back to the value it had on entry; keys that did not exist on entry
    are removed again.

    This is used instead of pytest's monkeypatch because monkeypatch doesn't
    work with "module" scoped fixtures.
    """
    # Snapshot first: None marks a variable that was not set on entry.
    saved = {}
    for name in env_vars:
        saved[name] = os.environ.get(name)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)


35
36
37
@pytest.fixture(scope="class")
def llm_pair(request):
    """
    Build two LLM instances for the same model: one with full CUDA graph
    compilation enabled and one with the default CompilationConfig.

    The model name is taken from ``request.param``, so tests select it via
    ``pytest.mark.parametrize("llm_pair", [...], indirect=True)``.

    Yields:
        A ``(full, piecewise)`` tuple of weakref proxies, in that order.

    After the tests using the fixture finish, both instances are dropped and
    the fixture waits for GPU memory on device 0 to be released.
    """
    model = request.param

    # The environment overrides are only active while the two engines are
    # constructed; they are restored before the fixture yields.
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        # Each instance is capped at 0.45 of GPU memory
        # (presumably so that both fit on one device at the same time).
        full = LLM(
            model=model,
            gpu_memory_utilization=0.45,
            trust_remote_code=True,
            max_model_len=1024,
            compilation_config=CompilationConfig(full_cuda_graph=True),
        )
        piecewise = LLM(
            model=model,
            gpu_memory_utilization=0.45,
            trust_remote_code=True,
            max_model_len=1024,
            compilation_config=CompilationConfig(),
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


@pytest.mark.parametrize(
    "llm_pair",
    [
        # Model names for the llm_pair fixture
        "deepseek-ai/DeepSeek-V2-Lite",
        "Qwen/Qwen2-1.5B-Instruct"
    ],
    indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                    reason="Only Hopper GPUs support FA3 and FlashMLA")
class TestFullCUDAGraph:
    """
    Use a class such that an llm pair is constructed once for all
    batch_size/max_tokens combinations and released immediately after.

    Module-scope fixtures would stick around the whole time,
    meaning there would be multiple LLM instances hogging memory simultaneously.
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (1, 10),
        (7, 10),
        (16, 10),
        (25, 10),
        (32, 10),
        (45, 10),
        (64, 10),
        (123, 10),
        (8, 5),
        (8, 30),
    ])
    def test_full_cudagraph(self, batch_size, max_tokens,
                            llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.

        The same prompts are generated greedily (temperature 0) by both
        LLMs and the resulting texts must match pairwise.
        """

        # The llm_pair fixture yields (full, piecewise), in that order.
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["Hello, my name is"] * batch_size
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=0.95)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # zip() stops at the shorter input, so make sure no response would
        # be silently left out of the comparison below.
        assert len(piecewise_responses) == len(full_responses)

        # Check that all responses are the same
        for piecewise_res, full_res in zip(piecewise_responses,
                                           full_responses):
            assert piecewise_res.outputs[0].text == full_res.outputs[0].text


@pytest.mark.parametrize(
    "model, supported",
    [
        ("Qwen/Qwen2-1.5B-Instruct", True),
        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
        ("deepseek-ai/DeepSeek-V2-Lite", False),
    ])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                    reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
    """
    Run full CUDA graph compilation with a capture size (512) that is
    larger than max_num_seqs (256).

    For models marked as supported, construction and generation must
    succeed; for the others a RuntimeError is expected.
    """
    env_overrides = {"VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3"}

    with temporary_environ(env_overrides):
        with ExitStack() as expectations:
            # Only unsupported models are expected to fail.
            if not supported:
                expectations.enter_context(pytest.raises(RuntimeError))

            graph_config = CompilationConfig(
                full_cuda_graph=True,
                cudagraph_capture_sizes=[64, 256, 512],
            )
            llm = LLM(
                model=model,
                max_num_seqs=256,
                trust_remote_code=True,
                max_model_len=1024,
                compilation_config=graph_config,
            )
            prompts = ["Hello, my name is"] * 10
            llm.generate(prompts)
148
149


150
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """
    Constructing an LLM with full_cuda_graph=True while FlashAttention
    version 2 is selected must raise a RuntimeError.
    """
    with temporary_environ({
            "VLLM_USE_V1": "1",
            # FA2 is not supported with full_cuda_graph
            "VLLM_FLASH_ATTN_VERSION": "2",
    }), pytest.raises(RuntimeError):
        LLM(model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(full_cuda_graph=True))