test_config.py 1.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)

from .piecewise.test_simple import SillyModel


14
15
16
17
18
19
20
21
22
23
def test_use_cudagraphs_dynamic(monkeypatch):
    """Check that the default ``use_cudagraph`` value tracks VLLM_USE_V1.

    A default-constructed ``VllmConfig`` is expected to have cudagraphs
    enabled while ``VLLM_USE_V1`` is truthy, and disabled once the
    environment variable is set to ``'0'``.
    """
    # Precondition: the test must start with V1 enabled.
    assert vllm.envs.VLLM_USE_V1
    compilation_with_v1 = VllmConfig().compilation_config
    assert compilation_with_v1.use_cudagraph

    # Flip the env var for the rest of this test only (monkeypatch restores
    # it afterwards) and build a fresh config to observe the new default.
    monkeypatch.setenv('VLLM_USE_V1', '0')
    compilation_without_v1 = VllmConfig().compilation_config
    assert not compilation_without_v1.use_cudagraph


24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(enabled):
    """Check that ``use_cudagraph`` controls whether a CUDA graph is captured.

    Builds a ``SillyModel`` under a piecewise compilation config with a
    single capture size of 100, calls it twice on a 100-element CUDA tensor,
    and expects the compilation counter to report one graph seen and one
    captured cudagraph only when ``enabled`` is true.
    """
    assert vllm.envs.VLLM_USE_V1

    compilation_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=enabled,
        cudagraph_capture_sizes=[100],
    )
    vllm_config = VllmConfig(compilation_config=compilation_config)

    # The model is constructed while this config is the current one.
    with set_current_vllm_config(vllm_config):
        silly_model = SillyModel(vllm_config=vllm_config, prefix='')

    batch = torch.randn(100, device="cuda")
    expected_captures = 1 if enabled else 0

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_cudagraph_captured=expected_captures,
    ):
        # Two calls: the first is warmup, the second does the CUDAGraphs
        # recording (if enabled).
        for _ in range(2):
            silly_model(batch)