# test_full_graph.py
1
2
import pytest

3
from vllm.compilation.levels import CompilationLevel
4

5
from ..utils import fork_new_process_for_each_test
6
from .utils import TEST_MODELS, check_full_graph_support
7
8


9
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
    """Run ``check_full_graph_support`` for one model at one compilation level.

    The test is parametrized over every entry of ``TEST_MODELS`` crossed with
    the ``DYNAMO_ONCE`` and ``INDUCTOR`` compilation levels, and always passes
    ``tp_size=1`` to the checker.

    Args:
        model_info: One entry of ``TEST_MODELS``. Element 0 is forwarded as
            the model and element 1 as the model kwargs (presumably a
            ``(model_name, kwargs_dict)`` pair -- confirm against
            ``TEST_MODELS`` in ``.utils``).
        optimization_level: The ``CompilationLevel`` value forwarded to
            ``check_full_graph_support``.
    """
    # NOTE(review): ``fork_new_process_for_each_test`` presumably isolates
    # each parametrized case in its own process -- confirm in ``..utils``.
    model = model_info[0]
    model_kwargs = model_info[1]
    check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1)