import contextlib
import functools
import gc

import pytest
import ray
import torch

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig


@pytest.fixture(autouse=True)
def cleanup():
    """Tear down distributed, Ray and CUDA state before every test.

    This is an autouse fixture with no ``yield``, so the whole body runs
    during *setup* of each test in this conftest's scope. Each test therefore
    starts without whatever model-parallel groups, process groups, Ray
    runtime or cached CUDA memory an earlier test left behind.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # torch.distributed.destroy_process_group() can raise AssertionError when
    # there is no default process group to destroy (e.g. none was created,
    # or it is already gone); that case is expected here and ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    ray.shutdown()
    # Collect unreachable Python objects first so any CUDA tensors they hold
    # are freed, then release the caching allocator's unused blocks.
    gc.collect()
    torch.cuda.empty_cache()


def retry_until_skip(n):
    """Decorator factory: retry a test on ``AssertionError``, then skip it.

    The wrapped callable is re-invoked each time it raises ``AssertionError``,
    up to ``n`` attempts in total. After every failed attempt, unreachable
    Python objects are garbage-collected and the CUDA caching allocator is
    emptied, presumably so memory left by a failed attempt does not affect
    the next one. If all ``n`` attempts fail, the test is skipped via
    ``pytest.skip`` instead of being reported as a failure. Any other
    exception type propagates immediately, without a retry.

    Args:
        n: Maximum number of attempts; must be at least 1.

    Returns:
        A decorator that wraps a test function. ``functools.wraps`` keeps the
        original name, docstring and ``__wrapped__`` reference, so pytest
        still sees the original signature for fixture injection.

    Raises:
        ValueError: If ``n`` is less than 1. Such a value would otherwise
            make the wrapper return ``None`` without ever running the test,
            so the test would silently pass.
    """
    if n < 1:
        raise ValueError(f"retry_until_skip requires n >= 1, got {n!r}")

    def decorator_retry(func):

        @functools.wraps(func)
        def wrapper_retry(*args, **kwargs):
            for attempt in range(1, n + 1):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    gc.collect()
                    torch.cuda.empty_cache()
                    if attempt == n:
                        # pytest.skip raises, so control never falls off the
                        # end of the loop once all attempts are exhausted.
                        pytest.skip(
                            f"Skipping test after {n} failed attempts.")

        return wrapper_retry

    return decorator_retry


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is ``"vllm"``.

    Declared ``autouse``, so pytest instantiates it for every test in this
    conftest's scope; tests that need the object can request it by name.
    """
    return TensorizerConfig(tensorizer_uri="vllm")