# conftest.py (1.16 KB) -- pytest fixtures and retry helper
1
import functools
2
import gc
3
from typing import Callable, TypeVar
4
5
6

import pytest
import torch
7
from typing_extensions import ParamSpec
8

9
from vllm.distributed import cleanup_dist_env_and_memory
10
11
12
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig


13
@pytest.fixture(autouse=True)
def cleanup():
    """Call ``cleanup_dist_env_and_memory(shutdown_ray=True)`` for every test.

    Declared with ``autouse=True`` so tests get it without requesting it.
    The fixture has no ``yield``, so the call happens during fixture setup,
    i.e. before the test body runs.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
16
17


18
19
# Type variables used to annotate the decorator built by ``retry_until_skip``:
# ``_P`` stands for the wrapped callable's parameters, ``_R`` for its return
# type, so the wrapper is typed with the same signature as the original.
_P = ParamSpec("_P")
_R = TypeVar("_R")
20

21
22
23
24

def retry_until_skip(n: int):

    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
25

26
        @functools.wraps(func)
27
        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
28
29
30
31
32
33
34
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    gc.collect()
                    torch.cuda.empty_cache()
                    if i == n - 1:
35
36
37
                        pytest.skip(f"Skipping test after {n} attempts.")

            raise AssertionError("Code should not be reached")
38

39
40
41
        return wrapper_retry

    return decorator_retry
42
43
44
45
46


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` built with ``tensorizer_uri="vllm"``.

    Declared with ``autouse=True``; tests that name ``tensorizer_config`` as
    an argument receive the returned object.
    """
    return TensorizerConfig(tensorizer_uri="vllm")