"vllm/vscode:/vscode.git/clone" did not exist on "ada6f91d561ab693fd85f028bc44b8c8058d3073"
conftest.py 1.42 KB
Newer Older
1
import contextlib
2
import functools
3
import gc
4
from typing import Callable, TypeVar
5
6
7
8

import pytest
import ray
import torch
9
from typing_extensions import ParamSpec
10
11
12
13
14
15

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig


16
@pytest.fixture(autouse=True)
def cleanup():
    """Reset distributed, Ray and CUDA state before each test.

    This is an ``autouse`` fixture without a ``yield``, so its whole body
    runs during test setup. It tears down whatever a previous test may have
    left behind:

    * vLLM's model-parallel groups and distributed environment,
    * the ``torch.distributed`` default process group,
    * any running Ray instance,
    * unreferenced Python objects and PyTorch's cached CUDA allocations.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # destroy_process_group() may raise AssertionError when there is no
    # process group left to destroy (behavior varies across torch versions,
    # TODO confirm); that case is expected here and safe to ignore.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    ray.shutdown()
    # Drop dangling Python references first so the CUDA caching allocator
    # can actually release the blocks they were holding.
    gc.collect()
    torch.cuda.empty_cache()


27
28
_P = ParamSpec("_P")
_R = TypeVar("_R")
29

30
31
32
33

def retry_until_skip(n: int):

    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
34

35
        @functools.wraps(func)
36
        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
37
38
39
40
41
42
43
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    gc.collect()
                    torch.cuda.empty_cache()
                    if i == n - 1:
44
45
46
                        pytest.skip(f"Skipping test after {n} attempts.")

            raise AssertionError("Code should not be reached")
47

48
49
50
        return wrapper_retry

    return decorator_retry
51
52
53
54
55


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is ``"vllm"``.

    Because the fixture is ``autouse``, a fresh config is built for every
    test. Tests that need it request it by the parameter name
    ``tensorizer_config``.
    """
    return TensorizerConfig(tensorizer_uri="vllm")