conftest.py 1.37 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
import gc
5
from typing import Callable, TypeVar
6
7
8

import pytest
import torch
9
from typing_extensions import ParamSpec
10

11
from vllm.distributed import cleanup_dist_env_and_memory
12
13
14
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig


15
16
17
18
19
20
21
22
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Set ``VLLM_USE_V1`` to ``'0'`` for every test this conftest covers.

    Tensorizer has only been tested on V0 so far. The variable is set
    through the ``monkeypatch`` fixture, so pytest undoes the change when
    the test finishes.
    """
    env_name, env_value = 'VLLM_USE_V1', '0'
    monkeypatch.setenv(env_name, env_value)


23
@pytest.fixture(autouse=True)
def cleanup() -> None:
    """Call ``cleanup_dist_env_and_memory`` with ``shutdown_ray=True``.

    The fixture is autouse and has no ``yield``, so the call happens during
    fixture setup, i.e. before the body of each test runs.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
26
27


28
29
_P = ParamSpec("_P")
_R = TypeVar("_R")
30

31
32
33
34

def retry_until_skip(n: int):

    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
35

36
        @functools.wraps(func)
37
        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
38
39
40
41
42
43
44
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    gc.collect()
                    torch.cuda.empty_cache()
                    if i == n - 1:
45
46
47
                        pytest.skip(f"Skipping test after {n} attempts.")

            raise AssertionError("Code should not be reached")
48

49
50
51
        return wrapper_retry

    return decorator_retry
52
53
54
55
56


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Return a ``TensorizerConfig`` built with ``tensorizer_uri="vllm"``."""
    return TensorizerConfig(tensorizer_uri="vllm")