helper.py 1.22 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
from functools import wraps
Kai Zhang's avatar
Kai Zhang committed
7
from tempfile import TemporaryDirectory
facebook-github-bot's avatar
facebook-github-bot committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

import torch
import torch.distributed as dist


def skip_if_no_gpu(func):
    """Decorator that can be used to skip GPU tests on non-GPU machines.

    When no CUDA device is available, calling the decorated function raises
    ``unittest.SkipTest``. Both unittest and pytest report that as a skipped
    test rather than a silent pass. Otherwise ``func`` is called normally
    and its return value is passed through.

    The ``skip_if_no_gpu`` attribute is set on ``func``. ``functools.wraps``
    copies it onto the returned wrapper, so code that inspects it keeps
    working.
    """
    import unittest

    func.skip_if_no_gpu = True

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Raising SkipTest (instead of returning early) keeps GPU-only tests
        # from being counted as passed on CPU-only machines.
        if not torch.cuda.is_available() or torch.cuda.device_count() <= 0:
            raise unittest.SkipTest("no CUDA GPU available")

        return func(*args, **kwargs)

    return wrapper


def enable_ddp_env(func):
    """Run ``func`` inside a single-process ``torch.distributed`` group.

    A "gloo" process group with ``rank=0`` and ``world_size=1`` is created
    before ``func`` is called. The group is destroyed afterwards, even if
    ``func`` raises, so that later decorated calls can initialize it again.
    The return value of ``func`` is passed through.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # A fresh, empty init file per call avoids reusing a stale or
        # concurrently used file-store path between runs.
        with TemporaryDirectory() as tmp_dir:
            init_file = os.path.join(tmp_dir, "ddp_init")
            dist.init_process_group(
                "gloo",
                rank=0,
                world_size=1,
                init_method="file://" + init_file,
            )
            try:
                return func(*args, **kwargs)
            finally:
                # Always tear down; otherwise the next init_process_group
                # call fails because the default group still exists.
                dist.destroy_process_group()

    return wrapper
Kai Zhang's avatar
Kai Zhang committed
44
45
46
47
48
49
50
51
52
53

def tempdir(func):
    """Decorate a method so it receives a fresh temporary directory.

    The directory path is inserted as the first argument after ``self``.
    The directory and everything in it are deleted once the method returns
    or raises. The method's return value is passed through unchanged.
    """

    @wraps(func)
    def _with_temp_dir(self, *args, **kwargs):
        temp_manager = TemporaryDirectory()
        with temp_manager as dir_path:
            result = func(self, dir_path, *args, **kwargs)
        return result

    return _with_temp_dir