# test_full_graph.py
import pytest

from vllm.compilation.backends import vllm_backend

from .utils import TEST_MODELS, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
    """Run ``check_full_graph_support`` for one test model on one backend.

    Parametrized over every entry of ``TEST_MODELS`` and over two backends:
    the string ``"eager"`` and ``vllm_backend``.

    Args:
        model_info: One entry of ``TEST_MODELS``. Element 0 is forwarded as
            the model and element 1 as its keyword arguments.
        backend: Backend value forwarded unchanged to the helper.
    """
    name, kwargs = model_info[0], model_info[1]
    # NOTE(review): tp_size is presumably the tensor-parallel size; 1 would
    # mean no sharding. Confirm against check_full_graph_support.
    check_full_graph_support(name, kwargs, backend, tp_size=1)