import contextlib
import os
import re
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch.cuda
import torch.multiprocessing as mp
from nanotron.parallel import ParallelContext
from packaging import version


def available_gpus():
    if not torch.cuda.is_available():
        return 0

    device_properties = [torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count())]

    # We filter out display-only GPUs that cannot be used for compute.
    blacklisted_gpu_names = {"NVIDIA DGX Display"}
    device_properties = [property_ for property_ in device_properties if property_.name not in blacklisted_gpu_names]

    # TODO @thomasw21: Can we do this cross node
    return len(device_properties)


# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mock_os_environ(remove_keys: Optional[List[str]] = None, update_key_values: Optional[Dict[str, Any]] = None):
    """
    Temporarily updates the ``os.environ`` dictionary in-place and restores it on exit.
    Updating in-place (rather than replacing the mapping) ensures the modification works in all situations.
    Args:
      remove_keys: Environment variables to remove.
      update_key_values: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update_key_values = update_key_values or {}
    remove_keys = remove_keys or []

    update_keys = set(update_key_values.keys())
    remove_keys = set(remove_keys)
    assert remove_keys.isdisjoint(update_keys)

    stomped = (update_keys | remove_keys) & set(env.keys())
    # Environment variables and their original values to restore on exit.
    reverse_change = {k: env[k] for k in stomped}
    # Environment variables that did not exist before and must be removed on exit.
    added_keys = update_keys - stomped

    try:
        env.update(update_key_values)
        for k in remove_keys:
            env.pop(k, None)
        yield
    finally:
        env.update(reverse_change)
        for k in added_keys:
            env.pop(k, None)


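# Illustrative sketch (not part of the original module): how ``mock_os_environ``
# might be used in a test. The helper name below is hypothetical.
def _example_mock_os_environ_usage():
    with mock_os_environ(update_key_values={"MASTER_PORT": "29500"}):
        # Inside the context the override is visible process-wide.
        assert os.environ["MASTER_PORT"] == "29500"
    # On exit the previous value (or absence) of MASTER_PORT is restored.

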
def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> Tuple[bool, Optional[str]]:
    """Returns True or False if the dictionaries match, and an additional message when it's False"""
    if sub_paths is None:
        sub_paths = []

    first_keys = set(first.keys())
    second_keys = set(second.keys())
    if first_keys != second_keys:
        return False, f"Keys don't match in {'.'.join(sub_paths)}.\nCur: {first_keys}\nRef: {second_keys}"
    for key in first_keys:
        first_elt = first[key]
        second_elt = second[key]

        if isinstance(first_elt, dict):
            if not isinstance(second_elt, dict):
                return (
                    False,
                    f"Object types don't match in {'.'.join(sub_paths +  [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}",
                )
            match, msg = is_dict_equal(first_elt, second_elt, sub_paths=sub_paths + [str(key)])
            if match is False:
                return False, msg
        elif isinstance(first_elt, torch.Tensor):
            if not isinstance(second_elt, torch.Tensor):
                return (
                    False,
                    f"Object types don't match in {'.'.join(sub_paths +  [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}",
                )
            try:
                torch.testing.assert_close(
                    first_elt,
                    second_elt,
                    atol=0.0,
                    rtol=0.0,
                    msg=lambda msg: f"Tensors at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}",
                )
            except AssertionError as error:
                return False, error.args[0]
        else:
            if first_elt != second_elt:
                return (
                    False,
                    f"Objects at key {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}",
                )

    return True, None


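# Illustrative sketch (not part of the original module): comparing two nested dicts with
# ``is_dict_equal``; tensor values are compared exactly (atol=0.0, rtol=0.0). The helper
# name below is hypothetical.
def _example_is_dict_equal_usage():
    first = {"weights": {"layer_0": torch.zeros(2)}, "step": 1}
    second = {"weights": {"layer_0": torch.zeros(2)}, "step": 2}
    match, msg = is_dict_equal(first, second)
    # The tensors match, but "step" differs, so the path of the mismatch is reported.
    assert match is False and "step" in msg

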
def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
    """Given a number of gpus, we want all 3d configurations possible such that pp * dp * tp = gpus"""
    result = []
    for tp in range(1, gpus + 1):
        if gpus % tp != 0:
            continue
        gpus_left_after_tp = gpus // tp
        for dp in range(1, gpus_left_after_tp + 1):
            if gpus_left_after_tp % dp != 0:
                continue
            gpus_left_after_dp = gpus_left_after_tp // dp
            for pp in range(1, gpus_left_after_dp + 1):
                if gpus_left_after_dp % pp != 0:
                    continue
                if tp * dp * pp == gpus:
                    result.append((pp, dp, tp))
    return result


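# Illustrative sketch (not part of the original module): with 8 GPUs this yields every
# (pp, dp, tp) triple whose product is 8, e.g. (1, 1, 8), (2, 2, 2) and (8, 1, 1). The
# helper name below is hypothetical.
def _example_get_all_3d_configurations_usage():
    configs = get_all_3d_configurations(8)
    assert all(pp * dp * tp == 8 for pp, dp, tp in configs)
    assert (2, 2, 2) in configs

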
def rerun_if_address_is_in_use(max_try: int = 500):
    """
    Reruns the wrapped function when an "address already in use" error occurs
    in tests spawned with torch.multiprocessing.

    Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157

    Usage::

        @rerun_if_address_is_in_use()
        def test_something():
            ...

    """
    # check version
    torch_version = version.parse(torch.__version__)
    assert torch_version.major >= 1

    # only torch >= 1.8 has ProcessRaisedException
    if torch_version >= version.parse("1.8.0"):
        exception = torch.multiprocessing.ProcessRaisedException
    else:
        exception = Exception

    func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try)
    return func_wrapper


def rerun_on_exception(exception_type: Exception = Exception, pattern: Optional[str] = None, max_try: Optional[int] = 10) -> Callable:
    """
    A decorator on a function to re-run when an exception occurs.

    Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71

    Usage::

        # rerun for all kinds of exception
        @rerun_on_exception()
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun for RuntimeError only
        @rerun_on_exception(exception_type=RuntimeError)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun at most 10 times if a RuntimeError occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=10)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun indefinitely if a RuntimeError occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=None)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun indefinitely, but only if the exception message
        # matches the given pattern
        @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

    Args:
        exception_type (Exception, Optional): The type of exception to detect for rerun
        pattern (str, Optional): The pattern to match the exception message.
            If the pattern is not None and matches the exception message,
            the exception will be detected for rerun
        max_try (int, Optional): Maximum number of attempts for this function. The default value is 10.
            If max_try is None, it will rerun indefinitely while the exception keeps occurring
    """

    def _match_lines(lines, pattern):
        for line in lines:
            if re.match(pattern, line):
                return True
        return False

    def _wrapper(func):
        def _run_until_success(*args, **kwargs):
            try_count = 0
            assert max_try is None or isinstance(
                max_try, int
            ), f"Expected max_try to be None or int, but got {type(max_try)}"

            while max_try is None or try_count < max_try:
                try:
                    try_count += 1
                    ret = func(*args, **kwargs)
                    return ret
                except exception_type as e:
                    error_lines = str(e).split("\n")
                    if (max_try is None or try_count < max_try) and (
                        pattern is None or _match_lines(error_lines, pattern)
                    ):
                        # when pattern is not specified, we always retry on the exception
                        # when pattern is specified, we only retry when the pattern matches
                        print("Exception is caught, retrying...")
                        continue
                    else:
                        print("Maximum number of attempts is reached or pattern is not matched, no more retrying...")
                        raise e

        # Override signature
        # otherwise pytest.mark.parameterize will raise the following error:
        # function does not use argument xxx
        sig = signature(func)
        _run_until_success.__signature__ = sig

        return _run_until_success

    return _wrapper


def global_wrapper(rank, func, tp, pp, dp, port, kwargs):
    def setup_dist_env(rank, world_size, port):
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        # NOTE: unit tests run on a single node, so the local rank equals the global rank.
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

    world_size = tp * pp * dp
    setup_dist_env(rank, world_size, port)
    parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp)
    func(parallel_context, **kwargs)


def init_distributed(tp: int, dp: int, pp: int):
    def _init_distributed(func):
        def wrapper(**kwargs):
            from nanotron.utils import find_free_port

            world_size = tp * pp * dp
            port = find_free_port()

            # Note that kwargs must be passed inside args so that each spawned process can unpack them
            args = (func, tp, pp, dp, port, kwargs)
            mp.spawn(global_wrapper, args=args, nprocs=world_size)

        return wrapper

    return _init_distributed
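

# Illustrative sketch (not part of the original module): how ``init_distributed`` and
# ``rerun_if_address_is_in_use`` are typically combined in a test. The function names and
# the tp/dp/pp values below are hypothetical; the decorated body receives a ParallelContext
# as its first argument, followed by the keyword arguments passed to the wrapper.
def _example_distributed_body(parallel_context: ParallelContext, expected_world_size: int):
    # Runs once per spawned rank, after the distributed environment has been set up.
    assert int(os.environ["WORLD_SIZE"]) == expected_world_size


@rerun_if_address_is_in_use()
def _example_test_init_distributed():
    init_distributed(tp=2, dp=1, pp=1)(_example_distributed_body)(expected_world_size=2)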