# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We're not responsible for pytest decorators
# mypy: disallow_untyped_decorators = False

"""
Collection of some testing utilities for the Fairscale library. Please complement as
you see fit, but refrain from ad-hoc test utils within the different feature sets and
relative imports.
"""

import contextlib
import functools
import gc
import inspect
import logging
import multiprocessing
import os
import random
from statistics import mean
import subprocess
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.internal import torch_version
from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed

if TYPE_CHECKING:
    Base = nn.Module[Tensor]
else:
    Base = nn.Module


skip_if_cuda = pytest.mark.skipif(torch.cuda.is_available(), reason="Testing only on CPUs to save time")

skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
)

skip_if_single_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
)

skip_if_less_than_four_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
)

skip_if_py38 = pytest.mark.skipif(
    sys.version_info.major == 3 and sys.version_info.minor == 8, reason="Python3.8 is skipped"
)

skip_if_py39_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available() and sys.version_info.major == 3 and sys.version_info.minor == 9,
    reason="Python3.9 without CUDA is skipped",
)

skip_due_to_flakyness = pytest.mark.skip(
    reason="Flaky test to be fixed or removed",
)

available_devices = ["cpu"]
if torch.cuda.is_available():
    available_devices.append("cuda")


filename_mpi: Optional[str] = None


class IdentityLayer(Base):
    def __init__(self, size: int, scale: float = 1.0) -> None:
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self, *_: Any, **__: Any) -> Tensor:
        return self.weight


def set_random_seed(seed: int, model_parallel: bool = True) -> None:
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    if model_parallel:
        model_parallel_cuda_manual_seed(seed)


def in_circle_ci() -> bool:
    return os.path.exists("/home/circleci")


# Global variable to cache the results from the first nvidia-smi execution.
_smi_ver: Optional[str] = None


def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
    if compiled:
        numbering = torch.version.cuda.split(".")[:2]
    else:
        global _smi_ver
        if _smi_ver is None:

            def get_smi_ver() -> str:
                """Get CUDA version from nvidia-smi"""
                for line in subprocess.check_output("nvidia-smi".split()).decode("utf-8").split("\n"):
                    if "CUDA Version" in line:
                        res = line.split()[8]
                        assert res.startswith("10.") or res.startswith("11."), res
                        return res
                assert False

            _smi_ver = get_smi_ver()
        numbering = _smi_ver.split(".")[:2]
    return tuple(int(n) for n in numbering)


def make_cudnn_deterministic() -> None:
    """Make cudnn (matmul) deterministic"""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # TF32 also makes things nondeterministic. Disable it.
    torch.backends.cuda.matmul.allow_tf32 = False  # type: ignore
    torch.backends.cudnn.allow_tf32 = False  # type: ignore


def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
    """
    Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
    tests to be run concurrently.

    Return False if there are not enough GPUs present in the system.

    .. warning:: This limits the use case to all ranks being on the same node.
    """

    try:
        torch.distributed.rpc.shutdown()
    except Exception:
        pass

    print(f"dist init r={rank}, world={world_size}")

    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename
    url_rpc = "file://" + filename_rpc

    if torch_version() >= (1, 6, 0):
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if backend == "nccl" and torch.cuda.device_count() < world_size:
            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
            return False

        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

        tp_options = {"init_method": url_rpc}
        # Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
        if torch_version() == (1, 8, 0):
            if torch.cuda.is_available():
                # Workaround for https://github.com/pytorch/pytorch/issues/53844
                tp_options["_transports"] = ["ibv", "uv"]  # type: ignore
            else:
                # Workaround for https://github.com/pytorch/pytorch/issues/54266
                tp_options["_channels"] = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"]  # type: ignore

        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(**tp_options),
        )

    else:
        if world_size > 1:
            # TensorPipe is not available in Torch 1.5
            rpc.init_rpc(
                name=f"Test{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
            )
        elif torch.cuda.is_available():
            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
        else:
            return False

    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    return True


def get_worker_map() -> Dict[Any, Any]:
    return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}


def get_world_sizes() -> List[int]:
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def test_runner(
    rank: int, test_func: Callable, deterministic: bool = False, *args: List[Any], **kwargs: Dict[str, Any]
) -> None:
    # At this point we're in a new process, torch options need to be set again
    if deterministic:
        make_cudnn_deterministic()
        torch.manual_seed(1357)

    test_func(rank, *args, **kwargs)


def spawn_for_all_world_sizes(
    test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = [], deterministic: bool = False
) -> None:
    for world_size in world_sizes:
        _, filename = tempfile.mkstemp()
        _, filename_rpc = tempfile.mkstemp()

        try:
            # (lefaudeux) Let mp handle the process joining; join=False and manual
            # context handling have been unstable in the past.
            mp.spawn(
                test_runner,
                args=(test_func, deterministic, world_size, filename, filename_rpc, *args),
                nprocs=world_size,
                join=True,
            )
        finally:
            rmf(filename)
            rmf(filename_rpc)


def worker_process(
    rank: int, world_size: int, filename: str, filename_rpc: str, func: Callable, args: Any, error_queue: Any
) -> None:
    """Main function for unit tests launched with torch_spawn"""

    if not dist_init(rank, world_size, filename, filename_rpc):
        logging.warning("failed initializing torch distributed")
        teardown()
        return

    kwargs = {}
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        kwargs["pipeline_backend"] = "gloo"

    initialize_model_parallel(1, world_size, **kwargs)

    # Make sure that CUDA operations are repeatable
    context = (
        torch.backends.cudnn.flags(benchmark=False, deterministic=True)  # type: ignore
        if torch.cuda.is_available() and hasattr(torch.backends.cudnn, "flags")
        else contextlib.suppress()
    )

    if torch.cuda.is_available() and not hasattr(torch.backends.cudnn, "flags"):
        make_cudnn_deterministic()

    try:
        with context:
            func(*args)
        teardown()
    except BaseException as e:
        logging.warning(f" Rank {rank}: {e}")

        # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
        teardown()

        # If the function raises 'Skipped', this indicates pytest.skip(), so
        # forward it to parent so we can call pytest.skip() there
        if e.__class__.__name__ == "Skipped":
            error_queue.put(str(e))
            return

        raise e


def teardown() -> None:
    destroy_model_parallel()

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    try:
        # torch 1.5 hangs on shutdown if waiting for all processes
        torch.distributed.rpc.shutdown(graceful=False)
    except Exception:
        pass


def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
    if world_sizes is None:
        world_sizes = get_world_sizes()

    def prepare_test(func: Callable) -> Callable:
        """Function called with the test function as the argument. Generates a
        replacement which serves as the actual test function."""

        name = func.__name__
        parameters = inspect.signature(func).parameters

        if name.startswith("test"):
            raise ValueError(
                f"Tests marked with @torch_spawn (i.e. '{name}') should not have names beginning in 'test' as they will"
                " be picked up by pytest without running the spawn wrapper"
            )

        @functools.wraps(func)
        def replacement(*args: Any, **kwargs: Any) -> None:
            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                # TODO (Min): this global used to be assigned every time this file was imported.
                #     It is now assigned on first use. The behavior should be the same, but it is
                #     not clear this path is used or correct, since different processes would pass
                #     different file names to init_process_group below. By initializing here, we
                #     don't leave a temp file behind at import time.
                global filename_mpi
                if filename_mpi is None:
                    filename_mpi = tempfile.mkstemp()[1]

                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                torch.distributed.init_process_group("mpi", init_method=f"file://{filename_mpi}")

                world_size = torch.distributed.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                        teardown()
                    except BaseException as e:
                        teardown()
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise e
                else:
                    pytest.skip("Requested world size doesn't match current world size")
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)

        # Register a function with the same name, prefixed with "test_" in the
        # calling module, so it will be picked up by pytest
        current_frame = inspect.currentframe()
        assert current_frame is not None
        caller_module = inspect.getmodule(current_frame.f_back)
        setattr(caller_module, f"test_{name}", replacement)

        return func

    return prepare_test
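

# A minimal usage sketch for ``torch_spawn`` (the decorated function name and
# the world size below are illustrative assumptions, not part of the library).
# It is wrapped in a helper so that importing this module does not register an
# extra test by itself.
def _example_torch_spawn_usage() -> None:
    @torch_spawn([2])
    def ping_world_size() -> None:
        # Runs once per rank; pytest collects it as ``test_ping_world_size``.
        assert dist.get_world_size() == 2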


class _Block(Base):
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)  # type: ignore
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
        x = inputs[0]
        attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class GPT2(Base):
    """
    GPT2 pytorch implementation, for testing purposes in the image-GPT context
    Credits: https://github.com/teddykoker/image-gpt"""

    def __init__(
        self, embed_dim: int, num_heads: int, num_layers: int, num_positions: int, num_vocab: int, num_classes: int
    ) -> None:
        super().__init__()

        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(_Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        self.clf_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: Tensor, classify: bool = False) -> Any:  # type: ignore
        """
        Expect input as shape [sequence len, batch]
        If classify, return classification logits
        """
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token
        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
        h = torch.cat([sos, h[:-1, :, :]], dim=0)

        # add positional embeddings
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)

        # transformer
        for layer in self.layers:
            h = layer(h)

        h = self.ln_f(h)

        logits = self.head(h)

        if not classify:
            # return logits
            return logits

        h = torch.mean(h, dim=0)  # average pool over sequence
        # return classification logits and generative logits
        return self.clf_head(h), logits


def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.
    """
    if type(a) is not type(b):
        if raise_exception:
            raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            if raise_exception:
                raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k], raise_exception, k):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            if raise_exception:
                raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
            return False
        return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
    elif torch.is_tensor(a):
        try:
            # assert_allclose doesn't strictly test shape, dtype and device
            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            if not shape_dtype_device_match:
                if raise_exception:
                    msg = f"sizes: {a.size()} vs. {b.size()}, "
                    msg += f"types: {a.dtype} vs. {b.dtype}, "
                    msg += f"device: {a.device} vs. {b.device}"
                    raise AssertionError(msg)
                else:
                    return False
            # assert_allclose.
            torch.testing.assert_allclose(a, b)
            return True
        except (AssertionError, RuntimeError) as e:
            if raise_exception:
                if dict_key and isinstance(e, AssertionError):
                    # Add dict key to the assertion error.
                    msg = e.args[0]
                    new_msg = f"For dict key '{dict_key}': {msg}"
                    raise AssertionError(new_msg) from None
                else:
                    raise e
            else:
                return False
    else:
        return a == b
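

# A small illustration of the comparison semantics above (a sketch with made-up
# values): containers are compared recursively, and tensors must match in shape,
# dtype, device and value.
def _example_objects_are_equal() -> None:
    reference = {"weight": torch.ones(2), "step": 3}
    assert objects_are_equal(reference, {"weight": torch.ones(2), "step": 3})
    assert not objects_are_equal(reference, {"weight": torch.ones(2), "step": 4})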


def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message

    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message


def check_same_models_across_ranks(
    model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
) -> None:
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    for param in model.parameters():
        # collect the params across the rank
        receptacle = [param.clone() for _ in range(world_size)]
        dist.all_gather(receptacle, param, group=process_group)

        if rank == 0:
            for sync_p in receptacle[1:]:
                assert not params_should_be_equal or torch.all(
                    torch.eq(receptacle[0], sync_p)
                ), f"Models differ in between ranks {receptacle[0]} - {sync_p}"

    # Check that all the buffers are in sync (rank 0 is the authoritative reference)
    if check_broadcast_buffers:
        for buffer in model.buffers():
            receptacle = [buffer.clone() for _ in range(world_size)]
            dist.all_gather(receptacle, buffer, group=process_group)
            if rank == 0:
                for sync_b in receptacle[1:]:
                    assert not params_should_be_equal or torch.all(
                        torch.eq(receptacle[0], sync_b)
                    ), f"Models differ in between ranks {receptacle[0]} - {sync_b}"


class DeviceAndTypeCheckModule(Base):
    """A simple module for checking Tensor devices and dtypes."""

    def __init__(
        self,
        expected_input_dtype: Optional[torch.dtype] = None,
        expected_input_device: Optional[torch.device] = None,
        expected_param_dtype: Optional[torch.dtype] = None,
        expected_param_device: Optional[torch.device] = None,
        expected_loss_dtype: Optional[torch.dtype] = None,
        expected_loss_device: Optional[torch.device] = None,
        expected_buffer_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.expected_input_dtype = expected_input_dtype
        self.expected_input_device = expected_input_device
        self.expected_param_dtype = expected_param_dtype
        self.expected_param_device = expected_param_device
        self.expected_loss_dtype = expected_loss_dtype
        self.expected_loss_device = expected_loss_device
        self.expected_buffer_dtype = expected_buffer_dtype

        self.linear = nn.Linear(5, 5)
        self.register_buffer("buffer", torch.rand((5,)))

    def _check(
        self,
        key: str,
        x: Union[torch.device, torch.dtype],
        expected: Union[Optional[torch.device], Optional[torch.dtype]],
    ) -> None:
        assert expected in {None, x}, f"{key} ({x}) != expected ({expected})"

    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
        x = input[0]
        self._check("input.dtype", x.dtype, self.expected_input_dtype)
        self._check("input.device", x.device, self.expected_input_device)

        param = self.linear.weight
        self._check("param.dtype", param.dtype, self.expected_param_dtype)
        self._check("param.device", param.device, self.expected_param_device)
        self._check("buffer.dtype", self.buffer.dtype, self.expected_buffer_dtype)  # type: ignore
        x = x + self.buffer
        loss = (self.linear(x) + self.buffer).sum()
        self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
        self._check("loss.device", loss.device, self.expected_loss_device)

        return loss
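

# A hedged usage sketch for ``DeviceAndTypeCheckModule`` (the dtype choices are
# arbitrary examples): the module asserts as soon as an input, parameter, buffer
# or loss does not match the expectations it was constructed with.
def _example_device_and_type_check() -> None:
    module = DeviceAndTypeCheckModule(
        expected_input_dtype=torch.float32,
        expected_loss_dtype=torch.float32,
    )
    module(torch.rand(2, 5))  # passes: an fp32 input yields an fp32 loss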


@functools.lru_cache()
def get_cycles_per_ms() -> float:
    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep

    Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
    """

    def measure() -> float:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    # Get 10 values, remove the 2 max and 2 min, and return the average.
    # This avoids system disturbances that skew the results, e.g. the very
    # first cuda call likely does a bunch of init, which takes much longer
    # than subsequent calls.
    #
    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
    # and seems to return stable values. Therefore, we enable caching
    # using lru_cache decorator above.
    num = 10
    vals = []
    for _ in range(num):
        vals.append(measure())
    vals = sorted(vals)
    return mean(vals[2 : num - 2])
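

# How the cached value above is typically consumed (a sketch; the 50ms figure is
# an arbitrary example): ``torch.cuda._sleep`` takes a cycle count, so scaling by
# ``get_cycles_per_ms()`` keeps the current CUDA stream busy for roughly the
# requested number of milliseconds.
def _example_sleep_for_ms(ms: float = 50.0) -> None:
    if torch.cuda.is_available():
        torch.cuda._sleep(int(ms * get_cycles_per_ms()))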


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size


class SGDWithPausingCompute(torch.optim.SGD):
    def __init__(self, *args, **kwargs) -> None:  # type: ignore
        self.rank = kwargs["rank"]
        del kwargs["rank"]

        super().__init__(*args, **kwargs)

    def step(self, closure: Optional[Any] = None) -> Any:
        loss = super().step(closure=closure)

        # This is used to make sure that OSS and ShardedDDP enforce a proper stream synchronization
        # - Add a long cuda wait on a compute stream, non blocking from the CPU perspective
        with torch.cuda.stream(torch.cuda.Stream()):
            torch.cuda._sleep(100000000)

            # - optionally change the params on a per rank basis
            with torch.no_grad():
                for param_group in self.param_groups:
                    for param in param_group["params"]:
                        param *= 1.0 + self.rank / 10.0

        return loss


def state_dict_norm(state: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Compute the norm from a state_dict for simple comparison."""
    norm = torch.zeros(1)
    for v in state.values():
        if not v.is_floating_point():
            v = v.float()
        norm += v.norm()
    return norm


def rmf(filename: str) -> None:
    """Remove a file like rm -f."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass


@contextlib.contextmanager
def in_temporary_directory() -> Generator:
    """
    Context manager to create a temporary directory and remove
    it at the end of the context.
    """
    old_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
        try:
            yield temp_dir
        finally:
            os.chdir(old_cwd)


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
    """A context to get tempfiles and ensure they are cleaned up."""
    files = [tempfile.mkstemp()[1] for _ in range(num)]

    try:
        yield tuple(files)
    finally:
        # temp files could have been removed, so we use rmf.
        for name in files:
            rmf(name)
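

# A minimal sketch of how ``temp_files_ctx`` is usually paired with ``dist_init``
# in tests (the rank/world size below are illustrative): one temp file backs the
# process-group rendezvous, another backs RPC, and both are removed on exit.
def _example_temp_files_usage() -> None:
    with temp_files_ctx(num=2) as (sync_file, rpc_file):
        assert os.path.exists(sync_file) and os.path.exists(rpc_file)
        # e.g. dist_init(rank=0, world_size=1, filename=sync_file, filename_rpc=rpc_file)
    # Both files are cleaned up here, even if the body raised.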


def dump_all_tensors(rank: int) -> None:
    """Useful tool for debugging memory issues from the python side."""
    if rank != 0:
        return
    for obj in gc.get_objects():
        try:
            ttype = str(type(obj))
            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
                print(ttype, obj.shape, obj.dtype, obj.device, obj.storage().size())
        except Exception:
            pass
    print(torch.cuda.memory_summary())


def get_smi_memory() -> float:
    """Return process's GPU memory in MB."""
    pid = os.getpid()
    info_string = torch.cuda.list_gpu_processes()
    for line in info_string.splitlines():
        if str(pid) in line:
            toks = line.split()
            return float(toks[3])
    # If the process is not in the list, we are not using the GPU.
    return 0.0


def skip_a_test_if_in_CI() -> None:
    """Skip a test in circle CI"""
    if os.path.exists("/home/circleci"):
        pytest.skip("Sometimes a CI test failure is not reproducible locally, we skip them")