test_full_graph.py 686 Bytes
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import pytest

5
from vllm.config import CompilationLevel
6

7
from ..utils import fork_new_process_for_each_test
8
from .utils import TEST_MODELS, check_full_graph_support
9
10


11
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
    """Run ``check_full_graph_support`` for one test model at one level.

    The test is parametrized over every entry of ``TEST_MODELS`` crossed with
    the ``DYNAMO_ONCE`` and ``PIECEWISE`` compilation levels.

    Args:
        model_info: One entry of ``TEST_MODELS``. Element 0 is passed as the
            model and element 1 as the model keyword arguments.
        optimization_level: The ``CompilationLevel`` under test.
    """
    model = model_info[0]
    model_kwargs = model_info[1]
    # tp_size is pinned to 1 here; presumably this means no tensor
    # parallelism (single device) -- confirm against check_full_graph_support.
    check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1)