# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We're not responsible for pytest decorators
# mypy: disallow_untyped_decorators = False

"""
Collection of testing utilities for the Fairscale library. Please extend it as you see
fit, but avoid adding ad-hoc test utilities inside the individual feature sets, and avoid
relative imports.
"""

import contextlib
import functools
import gc
import inspect
import logging
import multiprocessing
import os
import random
from statistics import mean
import subprocess
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.internal import torch_version
from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed

if TYPE_CHECKING:
    Base = nn.Module[Tensor]
else:
    Base = nn.Module


skip_if_cuda = pytest.mark.skipif(torch.cuda.is_available(), reason="Testing only on CPUs to save time")

skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
)

skip_if_single_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
)

skip_if_less_than_four_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
)

skip_if_py38 = pytest.mark.skipif(
    sys.version_info.major == 3 and sys.version_info.minor == 8, reason="Python3.8 is skipped"
)

skip_if_py39_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available() and sys.version_info.major == 3 and sys.version_info.minor == 9,
    reason="Python3.9 without CUDA is skipped",
)

available_devices = ["cpu"]
if torch.cuda.is_available():
    available_devices.append("cuda")


filename_mpi: Optional[str] = None


class IdentityLayer(Base):
    def __init__(self, size: int, scale: float = 1.0) -> None:
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self, *_: Any, **__: Any) -> Tensor:
        return self.weight


def set_random_seed(seed: int, model_parallel: bool = True) -> None:
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    if model_parallel:
        model_parallel_cuda_manual_seed(seed)


def in_circle_ci() -> bool:
    return os.path.exists("/home/circleci")


# Global variable to cache the results from the first nvidia-smi execution.
_smi_ver: Optional[str] = None


def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
    if compiled:
        numbering = torch.version.cuda.split(".")[:2]
    else:
        global _smi_ver
        if _smi_ver is None:

            def get_smi_ver() -> str:
                """Get CUDA version from nvidia-smi"""
                for line in subprocess.check_output("nvidia-smi".split()).decode("utf-8").split("\n"):
                    if "CUDA Version" in line:
                        res = line.split()[8]
                        assert res.startswith("10.") or res.startswith("11."), res
                        return res
                assert False

            _smi_ver = get_smi_ver()
        numbering = _smi_ver.split(".")[:2]
    return tuple(int(n) for n in numbering)
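# Example (informal sketch): guard a test on CUDA versions, e.g. skip when the driver's
# CUDA runtime (as reported by nvidia-smi) is older than the toolkit torch was compiled
# against. The exact policy below is only an illustration, not something the suite requires.
#
#     if torch_cuda_version() < torch_cuda_version(compiled=True):
#         pytest.skip("driver CUDA version older than the compiled one")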


def make_cudnn_deterministic() -> None:
    """Make cudnn (matmul) deterministic"""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # TF32 also makes things nondeterministic. Disable it.
    torch.backends.cuda.matmul.allow_tf32 = False  # type: ignore
    torch.backends.cudnn.allow_tf32 = False  # type: ignore


def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
    """
    Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
    tests to be run concurrently.

    Return False if there are not enough GPUs present in the system.

    .. warning: This limits the use case to all ranks being on the same node
    """

    try:
        torch.distributed.rpc.shutdown()
    except Exception:
        pass

    print(f"dist init r={rank}, world={world_size}")

    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename
    url_rpc = "file://" + filename_rpc

    if torch_version() >= (1, 6, 0):
        backend = "nccl" if torch.cuda.is_available() else "gloo"

        if backend == "nccl" and torch.cuda.device_count() < world_size:
            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
            return False

        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

        tp_options = {"init_method": url_rpc}
        # Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
        if torch_version() == (1, 8, 0):
            if torch.cuda.is_available():
                # Workaround for https://github.com/pytorch/pytorch/issues/53844
                tp_options["_transports"] = ["ibv", "uv"]  # type: ignore
            else:
                # Workaround for https://github.com/pytorch/pytorch/issues/54266
                tp_options["_channels"] = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"]  # type: ignore

        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(**tp_options),
        )

    else:
        if world_size > 1:
            # TensorPipe is not available in Torch 1.5
            rpc.init_rpc(
                name=f"Test{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
            )
        elif torch.cuda.is_available():
            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
        else:
            return False

    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    return True
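# Example (informal sketch): a spawned worker would typically call dist_init with
# per-test temporary files so that concurrent tests never share a rendezvous file.
# The worker name below is made up for illustration.
#
#     def my_worker(rank: int, world_size: int, filename: str, filename_rpc: str) -> None:
#         if not dist_init(rank, world_size, filename, filename_rpc):
#             return  # not enough GPUs on this machine
#         ...  # run the actual checks, then call teardown()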

def get_worker_map() -> Dict[Any, Any]:
    return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}


def get_world_sizes() -> List[int]:
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def test_runner(
    rank: int, test_func: Callable, deterministic: bool = False, *args: List[Any], **kwargs: Dict[str, Any]
) -> None:
    # At this point we're in a new process, torch options need to be set again
    if deterministic:
        make_cudnn_deterministic()
        torch.manual_seed(1357)

    test_func(rank, *args, **kwargs)


def spawn_for_all_world_sizes(
    test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = [], deterministic: bool = False
) -> None:
    for world_size in world_sizes:
        _, filename = tempfile.mkstemp()
        _, filename_rpc = tempfile.mkstemp()

        try:
            # (lefaudeux) Let mp handle the process joining, join=False and handling context has
            # been unstable in the past.
            mp.spawn(
                test_runner,
                args=(test_func, deterministic, world_size, filename, filename_rpc, *args),
                nprocs=world_size,
                join=True,
            )
        finally:
            rmf(filename)
            rmf(filename_rpc)
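# Example (informal sketch): running a check on every locally reachable world size.
# `my_check` is a made-up name; it receives (rank, world_size, filename, filename_rpc)
# followed by whatever extra args are passed in.
#
#     def my_check(rank: int, world_size: int, filename: str, filename_rpc: str) -> None:
#         assert dist_init(rank, world_size, filename, filename_rpc)
#         teardown()
#
#     spawn_for_all_world_sizes(my_check, deterministic=True)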


def worker_process(
    rank: int, world_size: int, filename: str, filename_rpc: str, func: Callable, args: Any, error_queue: Any
) -> None:
    """Main function for unit tests launched with torch_spawn"""

    if not dist_init(rank, world_size, filename, filename_rpc):
        logging.warning("failed initializing torch distributed")
        teardown()
        return

    kwargs = {}
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        kwargs["pipeline_backend"] = "gloo"

    initialize_model_parallel(1, world_size, **kwargs)

    # Make sure that CUDA operations are repeatable
    context = (
        torch.backends.cudnn.flags(benchmark=False, deterministic=True)  # type: ignore
        if torch.cuda.is_available() and hasattr(torch.backends.cudnn, "flags")
        else contextlib.suppress()
    )

    if torch.cuda.is_available() and not hasattr(torch.backends.cudnn, "flags"):
        make_cudnn_deterministic()

    try:
        with context:
            func(*args)
        teardown()
    except BaseException as e:
        logging.warning(f" Rank {rank}: {e}")

        # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
        teardown()

        # If the function raises 'Skipped', this indicates pytest.skip(), so
        # forward it to parent so we can call pytest.skip() there
        if e.__class__.__name__ == "Skipped":
            error_queue.put(str(e))
            return

        raise e


def teardown() -> None:
    destroy_model_parallel()

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    try:
        # torch 1.5 hangs on shutdown if waiting for all processes
        torch.distributed.rpc.shutdown(graceful=False)
    except Exception:
        pass


def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
    if world_sizes is None:
        world_sizes = get_world_sizes()

    def prepare_test(func: Callable) -> Callable:
        """Function called with the test function as the argument. Generates a
        replacement which serves as the actual test function."""

        name = func.__name__
        parameters = inspect.signature(func).parameters

        if name.startswith("test"):
            raise ValueError(
                f"Tests marked with @torch_spawn (i.e. '{name}') should not have names beginning in 'test' as they will"
                " be picked up by pytest without running the spawn wrapper"
            )

        @functools.wraps(func)
        def replacement(*args: Any, **kwargs: Any) -> None:
            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                # TODO (Min): this global used to be assigned every time this file is imported.
                #     I changed it to be assigned on first use. Should be the same, but I am not
                #     sure this is used or is correct since different processes would have different
                #     file names to init_process_group below. By initing, here, we don't leave
                #     a temp file behind on importing time.
                global filename_mpi
                if filename_mpi is None:
                    filename_mpi = tempfile.mkstemp()[1]

                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                torch.distributed.init_process_group("mpi", init_method=f"file://{filename_mpi}")

                world_size = torch.distributed.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                        teardown()
                    except BaseException as e:
                        teardown()
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise e
                else:
                    pytest.skip("Requested world size doesn't match current world size")
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)

        # Register a function with the same name, prefixed with "test_" in the
        # calling module, so it will be picked up by pytest
        current_frame = inspect.currentframe()
        assert current_frame is not None
        caller_module = inspect.getmodule(current_frame.f_back)
        setattr(caller_module, f"test_{name}", replacement)

        return func

    return prepare_test
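# Typical usage (informal sketch; the check name and world size below are made up):
#
#     @torch_spawn([2])
#     def uneven_workload() -> None:
#         ...  # runs inside each spawned rank, torch.distributed is already initialized
#
#     # pytest then collects and runs the generated `test_uneven_workload` wrapper.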


class _Block(Base):
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)  # type: ignore
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
        x = inputs[0]
        attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class GPT2(Base):
    """
    GPT2 PyTorch implementation, for testing purposes in the image-GPT context.
    Credits: https://github.com/teddykoker/image-gpt
    """

    def __init__(
        self, embed_dim: int, num_heads: int, num_layers: int, num_positions: int, num_vocab: int, num_classes: int
    ) -> None:
        super().__init__()

        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(_Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        self.clf_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: Tensor, classify: bool = False) -> Any:  # type: ignore
        """
        Expect input as shape [sequence len, batch]
        If classify, return classification logits
        """
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token
        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
        h = torch.cat([sos, h[:-1, :, :]], dim=0)

        # add positional embeddings
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)

        # transformer
        for layer in self.layers:
            h = layer(h)

        h = self.ln_f(h)

        logits = self.head(h)

        if not classify:
            # return logits
            return logits

        h = torch.mean(h, dim=0)  # average pool over sequence
        # return classification logits and generative logits
        return self.clf_head(h), logits
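# Example (informal sketch): a small GPT2 instance as it might be used in a test; all
# sizes below are made-up illustration values.
#
#     model = GPT2(embed_dim=32, num_heads=2, num_layers=2, num_positions=16, num_vocab=16, num_classes=4)
#     tokens = torch.randint(0, 16, (16, 1))  # [sequence len, batch]
#     logits = model(tokens)  # -> [sequence len, batch, num_vocab]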


def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.
    """
    if type(a) is not type(b):
        if raise_exception:
            raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            if raise_exception:
                raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k], raise_exception, k):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            if raise_exception:
                raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
            return False
        return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
    elif torch.is_tensor(a):
        try:
            # assert_allclose doesn't strictly test shape, dtype and device
            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            if not shape_dtype_device_match:
                if raise_exception:
                    msg = f"sizes: {a.size()} vs. {b.size()}, "
                    msg += f"types: {a.dtype} vs. {b.dtype}, "
                    msg += f"device: {a.device} vs. {b.device}"
                    raise AssertionError(msg)
                else:
                    return False
            # assert_allclose.
            torch.testing.assert_allclose(a, b)
            return True
        except (AssertionError, RuntimeError) as e:
            if raise_exception:
                if dict_key and isinstance(e, AssertionError):
                    # Add dict key to the assertion error.
                    msg = e.args[0]
                    new_msg = f"For dict key '{dict_key}': {msg}"
                    raise AssertionError(new_msg) from None
                else:
                    raise e
            else:
                return False
    else:
        return a == b
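# Example (informal sketch; `model_a` and `model_b` are made-up names): comparing two
# state_dicts and getting a descriptive error instead of a plain False.
#
#     assert objects_are_equal(model_a.state_dict(), model_b.state_dict(), raise_exception=True)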


def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message

    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message


def check_same_models_across_ranks(
    model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
) -> None:
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    for param in model.parameters():
        # collect the params across the rank
        receptacle = [param.clone() for _ in range(world_size)]
        dist.all_gather(receptacle, param, group=process_group)

        if rank == 0:
            for sync_p in receptacle[1:]:
                assert not params_should_be_equal or torch.all(
                    torch.eq(receptacle[0], sync_p)
                ), f"Models differ in between ranks {receptacle[0]} - {sync_p}"

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    if check_broadcast_buffers:
        for buffer in model.buffers():
            receptacle = [buffer.clone() for _ in range(world_size)]
            dist.all_gather(receptacle, buffer, group=process_group)
            if rank == 0:
                for sync_b in receptacle[1:]:
                    assert not params_should_be_equal or torch.all(
                        torch.eq(receptacle[0], sync_b)
                    ), f"Models differ in between ranks {receptacle[0]} - {sync_b}"


class DeviceAndTypeCheckModule(Base):
    """A simple module for checking Tensor devices and dtypes."""

    def __init__(
        self,
        expected_input_dtype: Optional[torch.dtype] = None,
        expected_input_device: Optional[torch.device] = None,
        expected_param_dtype: Optional[torch.dtype] = None,
        expected_param_device: Optional[torch.device] = None,
        expected_loss_dtype: Optional[torch.dtype] = None,
        expected_loss_device: Optional[torch.device] = None,
        expected_buffer_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.expected_input_dtype = expected_input_dtype
        self.expected_input_device = expected_input_device
        self.expected_param_dtype = expected_param_dtype
        self.expected_param_device = expected_param_device
        self.expected_loss_dtype = expected_loss_dtype
        self.expected_loss_device = expected_loss_device
        self.expected_buffer_dtype = expected_buffer_dtype

        self.linear = nn.Linear(5, 5)
        self.register_buffer("buffer", torch.rand((5,)))

    def _check(
        self,
        key: str,
        x: Union[torch.device, torch.dtype],
        expected: Union[Optional[torch.device], Optional[torch.dtype]],
    ) -> None:
        assert expected in {None, x}, f"{key} ({x}) != expected ({expected})"

    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
        x = input[0]
        self._check("input.dtype", x.dtype, self.expected_input_dtype)
        self._check("input.device", x.device, self.expected_input_device)

        param = self.linear.weight
        self._check("param.dtype", param.dtype, self.expected_param_dtype)
        self._check("param.device", param.device, self.expected_param_device)
        self._check("buffer.dtype", self.buffer.dtype, self.expected_buffer_dtype)  # type: ignore
        x = x + self.buffer
        loss = (self.linear(x) + self.buffer).sum()
        self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
        self._check("loss.device", loss.device, self.expected_loss_device)

        return loss


@functools.lru_cache()
def get_cycles_per_ms() -> float:
    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep

    Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
    """

    def measure() -> float:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    # Get 10 values and remove the 2 max and 2 min and return the avg.
    # This is to avoid system disturbances that skew the results, e.g.
    # the very first cuda call likely does a bunch of init, which takes
    # much longer than subsequent calls.
    #
    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
    # and seems to return stable values. Therefore, we enable caching
    # using lru_cache decorator above.
    num = 10
    vals = []
    for _ in range(num):
        vals.append(measure())
    vals = sorted(vals)
    return mean(vals[2 : num - 2])
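# Example (informal sketch): spin the GPU for roughly 100ms, e.g. to expose a missing
# stream synchronization in a test.
#
#     torch.cuda._sleep(int(100 * get_cycles_per_ms()))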


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size


class SGDWithPausingCompute(torch.optim.SGD):
    def __init__(self, *args, **kwargs) -> None:  # type: ignore
        self.rank = kwargs["rank"]
        del kwargs["rank"]

        super().__init__(*args, **kwargs)

    def step(self, closure: Optional[Any] = None) -> Any:
        loss = super().step(closure=closure)

        # This is used to make sure that OSS and ShardedDDP enforce a proper stream synchronization
        # - Add a long cuda wait on a compute stream, non blocking from the CPU perspective
        with torch.cuda.stream(torch.cuda.Stream()):
            torch.cuda._sleep(100000000)

            # - optionally change the params on a per rank basis
            with torch.no_grad():
                for param_group in self.param_groups:
                    for param in param_group["params"]:
                        param *= 1.0 + self.rank / 10.0

        return loss
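# Example (informal sketch; `model` is a made-up name): the extra `rank` kwarg is
# consumed by this class and is not forwarded to torch.optim.SGD.
#
#     optim = SGDWithPausingCompute(model.parameters(), lr=0.1, rank=dist.get_rank())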


def state_dict_norm(state: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Compute the norm from a state_dict for simple comparison."""
    norm = torch.zeros(1)
    for v in state.values():
        if not v.is_floating_point():
            v = v.float()
        norm += v.norm()
    return norm


def rmf(filename: str) -> None:
    """Remove a file like rm -f."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass


@contextlib.contextmanager
def in_temporary_directory() -> Generator:
    """
    Context manager to create a temporary directory and remove
    it at the end of the context
    """
    old_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
        try:
            yield temp_dir
        finally:
            os.chdir(old_cwd)


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
    """A context to get tempfiles and ensure they are cleaned up."""
    files = [tempfile.mkstemp()[1] for _ in range(num)]

    try:
        yield tuple(files)
    finally:
        # temp files could have been removed, so we use rmf.
        for name in files:
            rmf(name)
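# Example (informal sketch): grabbing the two rendezvous files that dist_init expects.
#
#     with temp_files_ctx(2) as (sync_file, rpc_file):
#         dist_init(rank=0, world_size=1, filename=sync_file, filename_rpc=rpc_file)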


def dump_all_tensors(rank: int) -> None:
    """Useful tool for debugging memory issues from the python side."""
    if rank != 0:
        return
    for obj in gc.get_objects():
        try:
            ttype = str(type(obj))
            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
                print(ttype, obj.shape, obj.dtype, obj.device, obj.storage().size())
        except Exception:
            pass
    print(torch.cuda.memory_summary())


def get_smi_memory() -> float:
    """Return process's GPU memory in MB."""
    pid = os.getpid()
    info_string = torch.cuda.list_gpu_processes()
    for line in info_string.splitlines():
        if str(pid) in line:
            toks = line.split()
            return float(toks[3])
    # If the process is not in the list, we are not using the GPU.
    return 0.0


def skip_a_test_if_in_CI() -> None:
    """Skip a test in circle CI"""
    if os.path.exists("/home/circleci"):
        pytest.skip("Sometimes a CI test failure is not reproducible locally, we skip them")