import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

import deepspeed

import pytest
from functools import wraps
import unittest
from pathlib import Path

# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120


def get_xdist_worker_id():
    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace('gw', '')
        return int(xdist_worker_id)
    return None


def get_master_port():
    master_port = os.environ.get('DS_TEST_PORT', '29503')
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        master_port = str(int(master_port) + xdist_worker_id)
    return master_port
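

# For example (illustrative only): with DS_TEST_PORT=29503 and pytest-xdist worker
# 'gw2', get_master_port() above returns '29505', so concurrent xdist workers use
# distinct rendezvous ports.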


def set_cuda_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0
    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead
        import subprocess
        is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None
        if is_rocm_pytorch:
            rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
            gpu_ids = filter(lambda s: 'GPU' in s,
                             rocm_smi.decode('utf-8').strip().split('\n'))
            num_gpus = len(list(gpu_ids))
        else:
            nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
            num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n'))
        cuda_visible = ",".join(map(str, range(num_gpus)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


def distributed_test(world_size=2, backend='nccl'):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.
    This decorator manages the spawning and joining of processes, initialization of
    torch.distributed, and catching of errors.

    Usage example:
        @distributed_test(world_size=[2,3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert rank < world_size

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list of
            integers, in which case the test is run once per world size.
        backend (str): the torch.distributed backend to initialize. Defaults to 'nccl'.
    """
    def dist_wrap(run_func):
        """Second-level decorator for @distributed_test. This actually wraps the function. """
        def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function. """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = get_master_port()
            os.environ['LOCAL_RANK'] = str(local_rank)
            # NOTE: unit tests don't support multi-node so local_rank == global rank
            os.environ['RANK'] = str(local_rank)
            os.environ['WORLD_SIZE'] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop('NCCL_DEBUG', None)

            set_cuda_visible()

            deepspeed.init_distributed(dist_backend=backend)

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            run_func(*func_args, **func_kwargs)

            # make sure all ranks finish at the same time
            torch.distributed.barrier()

            # tear down after test completes
            torch.distributed.destroy_process_group()

        def dist_launcher(num_procs, *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures. """

            # Spawn all workers on subprocesses.
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Now loop and wait for a test to complete. The spin-wait here isn't a big
            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
            any_done = False
            while not any_done:
                for p in processes:
                    if not p.is_alive():
                        any_done = True
                        break

            # Wait for all other processes to complete
            for p in processes:
                p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
            for rank, p in failed:
                # If it still hasn't terminated, kill it because it hung.
                if p.exitcode is None:
                    p.terminate()
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if p.exitcode < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                                pytrace=False)
                if p.exitcode > 0:
                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                                pytrace=False)

        @wraps(run_func)
        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """

            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    time.sleep(0.5)
            else:
                raise TypeError('world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap
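

# Example usage (an illustrative sketch only; nothing in the test suite calls this).
# The function name and the `magnitude` keyword are made up for the example. It
# shows a test body that runs once per world size in the list and receives keyword
# arguments forwarded by the decorator, e.g. `_example_all_reduce_sum(magnitude=3.0)`.
@distributed_test(world_size=[1, 2])
def _example_all_reduce_sum(magnitude=1.0):
    # Each rank contributes `magnitude`; dist.all_reduce defaults to a sum.
    tensor = torch.ones(1).cuda() * magnitude
    dist.all_reduce(tensor)
    assert tensor.item() == magnitude * dist.get_world_size()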


def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))
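

# For example (illustrative; 'ds_config.json' is a hypothetical file name):
# get_test_path('ds_config.json') returns the path of that file in this module's
# directory, as a str.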