helper.py 2.15 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


Sam Tsai's avatar
Sam Tsai committed
5
import importlib
6
7
8
import os
import socket
import uuid
facebook-github-bot's avatar
facebook-github-bot committed
9
from functools import wraps
Kai Zhang's avatar
Kai Zhang committed
10
from tempfile import TemporaryDirectory
Sam Tsai's avatar
Sam Tsai committed
11
from typing import Optional
facebook-github-bot's avatar
facebook-github-bot committed
12
13

import torch
14
from d2go.distributed import distributed_worker, DistributedParams
facebook-github-bot's avatar
facebook-github-bot committed
15
16


Sam Tsai's avatar
Sam Tsai committed
17
18
19
20
21
22
23
24
25
26
27
def get_resource_path(file: Optional[str] = None):
    """Build a path into the ``resources`` directory that sits next to ``d2go.tests``.

    Args:
        file: Optional entry name to append below the resources directory.

    Returns:
        ``<directory of d2go.tests>/resources`` when ``file`` is None, otherwise
        ``<directory of d2go.tests>/resources/<file>``.
    """
    tests_pkg = importlib.import_module("d2go.tests")
    resources_dir = os.path.join(os.path.dirname(tests_pkg.__file__), "resources")
    if file is None:
        return resources_dir
    return os.path.join(resources_dir, file)


facebook-github-bot's avatar
facebook-github-bot committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def skip_if_no_gpu(func):
    """Decorator that skips a GPU test on machines without a usable CUDA device.

    The decorated function is tagged with ``skip_if_no_gpu = True``; ``functools.wraps``
    copies that attribute onto the wrapper as well (presumably read by test tooling
    elsewhere — not visible here).

    Raises:
        unittest.SkipTest: when the wrapper is called and CUDA is unavailable or no
            CUDA device is visible. Both unittest and pytest report this as a skipped
            test, instead of the silent "pass" produced by simply returning None.
    """
    import unittest  # local import keeps the module-level import block unchanged

    func.skip_if_no_gpu = True

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.cuda.is_available() or torch.cuda.device_count() <= 0:
            raise unittest.SkipTest("CUDA is not available on this machine")
        return func(*args, **kwargs)

    return wrapper


def _find_free_port() -> str:
    """Return, as a string, a TCP port on localhost that the OS reports as free.

    The socket is closed before returning, so the port is only free at that moment;
    another process could take it before it is used (best-effort, fine for tests).
    """
    # ``with`` closes the socket deterministically instead of leaking it to the GC.
    with socket.socket() as s:
        s.bind(("localhost", 0))  # Port 0: let the OS choose a free port.
        return str(s.getsockname()[1])


def enable_ddp_env(func):
    """Decorator that runs ``func`` through ``distributed_worker`` as a 1-process job.

    On every call the wrapper:
      * sets ``MASTER_ADDR=localhost`` and ``MASTER_PORT=<free port>`` in
        ``os.environ`` (these are not restored afterwards);
      * calls ``distributed_worker`` with the ``gloo`` backend, a unique
        ``file:///tmp/detectron2go_test_ddp_init_<uuid>`` init method, and
        ``DistributedParams`` describing a single process of rank 0 (world size 1).

    Returns:
        A wrapper that forwards ``*args``/``**kwargs`` to ``func`` via
        ``distributed_worker`` and returns its result. ``return_save_file=None`` is
        passed so that no result file is saved.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = _find_free_port()

        return distributed_worker(
            main_func=func,
            args=args,
            kwargs=kwargs,
            backend="gloo",
            # A fresh uuid per call keeps concurrent/repeated tests from sharing
            # the same file-based rendezvous.
            init_method=f"file:///tmp/detectron2go_test_ddp_init_{uuid.uuid4().hex}",
            dist_params=DistributedParams(
                local_rank=0,
                machine_rank=0,
                global_rank=0,
                num_processes_per_machine=1,
                world_size=1,
            ),
            return_save_file=None,  # don't save file
        )

    return wrapper
Kai Zhang's avatar
Kai Zhang committed
74

Yanghan Wang's avatar
Yanghan Wang committed
75

Kai Zhang's avatar
Kai Zhang committed
76
def tempdir(func):
    """Decorator that hands a freshly created temporary directory to a method.

    The wrapped method is invoked as ``func(self, <temp dir path>, *args, **kwargs)``.
    The directory and everything in it is removed when the call returns or raises.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with TemporaryDirectory() as tmp_dir:
            result = func(self, tmp_dir, *args, **kwargs)
        return result

    return wrapper