# coding=utf-8

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import numpy
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed


class IdentityLayer(torch.nn.Module):
    """A trivial module whose forward pass returns its learnable weight."""

    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility across Python, NumPy, and PyTorch."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)


def dist_init(rank, world_size):
    """Initialize the NCCL process group and bind this rank to its GPU."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def get_world_sizes():
    """Return the power-of-two world sizes that fit on the available GPUs."""
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes()):
    """Run test_func once per world size, spawning one process per rank.

    Note: the default world_sizes list is evaluated once, at import time.
    """
    for world_size in world_sizes:
        mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)
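
# ---------------------------------------------------------------------------
# Example usage: a minimal sketch showing how these helpers typically fit
# together. `_example_allreduce_test` is a hypothetical per-rank entry point,
# not part of the original file; mp.spawn supplies the rank as the first
# positional argument. Requires at least one CUDA device.
# ---------------------------------------------------------------------------
def _example_allreduce_test(rank, world_size):
    dist_init(rank, world_size)
    # Each rank contributes its rank index; after the (default SUM) all-reduce
    # every rank should hold 0 + 1 + ... + (world_size - 1).
    tensor = torch.full((1,), float(rank), device="cuda")
    dist.all_reduce(tensor)
    assert tensor.item() == sum(range(world_size))
    dist.destroy_process_group()


if __name__ == "__main__":
    spawn_for_all_world_sizes(_example_allreduce_test)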