# coding=utf-8

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import os
import random

import numpy
from packaging import version
import pytest
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp

from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed


class IdentityLayer(torch.nn.Module):
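    """Trivial module whose forward pass just returns its learnable weight parameter."""
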
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed: int) -> None:
    """Set random seed for reproducability."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)


def dist_init(rank, world_size, hostname=None):
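    """Set up torch.distributed (process group and/or RPC) for a single spawned test process."""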
    if hostname is None:
        hostname = "localhost"
    print(f"dist init r={rank}, world={world_size}, host={hostname}")
    os.environ["MASTER_ADDR"] = hostname
    os.environ["MASTER_PORT"] = "10638"
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)

    if version.parse(torch.__version__).release >= (1, 6, 0):
        init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
        os.environ["MASTER_ADDR"] = hostname
        os.environ["MASTER_PORT"] = "10639"
        init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
        )
    else:
        if world_size > 1:
            rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
        else:
            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())


def get_worker_map():
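    """Map each rank to the RPC worker name ("Test<rank>") registered in dist_init."""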
    return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}


def get_world_sizes():
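    """Return the candidate world sizes (1, 2, 4, 8) that do not exceed the local GPU count."""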
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
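    """Run `test_func` under mp.spawn once for each requested world size.

    Note: the default `world_sizes` list is evaluated once, at import time.
    """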
    for world_size in world_sizes:
        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)


def helper(rank, world_size, func, args):
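    """mp.spawn entry point: initialize distributed and model-parallel state, then run the test body."""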
    dist_init(rank, world_size)
    initialize_model_parallel(1, world_size)
    func(*args)


def torch_spawn(world_sizes=None):
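    """Decorator that registers a pytest-visible `test_<name>` wrapper which runs `func` across spawned processes."""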
    if world_sizes is None:
        world_sizes = get_world_sizes()

    def fixer(func):

        name = func.__name__
        parameters = inspect.signature(func).parameters

        if name.startswith("test"):
            raise ValueError(
                f"Tests marked with @torch_spawn (i.e. '{name}') should not have names beginning in 'test' as they will"
                " be picked up by pytest without running the spawn wrapper"
            )

        @functools.wraps(func)
        def replacement(*args, **kwargs):
            assert args == tuple()
            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            if "OMPI_COMM_WORLD_RANK" in os.environ:
                torch.distributed.init_process_group("mpi")
                world_size = torch.distributed.get_world_size()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
                if world_size in world_sizes:
                    func(*args)
                else:
                    pytest.skip(f"requested world size doesn't match current world size")
            else:
                spawn_for_all_world_sizes(helper, world_sizes, (func, args))

        caller_module = inspect.getmodule(inspect.currentframe().f_back)
        setattr(caller_module, f"test_{name}", replacement)

        return func

    return fixer
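

# Usage sketch (hypothetical example, not part of this module): decorating a function with
# @torch_spawn registers a pytest-collected `test_<name>` wrapper in the calling module, and
# the decorated body runs in every spawned worker after dist_init() and
# initialize_model_parallel() have been called.
#
#     @torch_spawn([2])
#     def allreduce_matches_world_size():
#         t = torch.ones(1)
#         if torch.cuda.is_available():
#             t = t.cuda()  # the NCCL backend is selected when CUDA is available
#         torch.distributed.all_reduce(t)
#         assert t.item() == torch.distributed.get_world_size()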