# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import pickle

import torch.distributed


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)
    if args.distributed_init_method.startswith('tcp://'):
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size, rank=args.distributed_rank)
    else:
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size)

    args.distributed_rank = torch.distributed.get_rank()
    if args.distributed_rank != 0:
        suppress_output()

    return args.distributed_rank


def suppress_output():
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        if 'force' in kwargs:
            force = kwargs.pop('force')
            if force:
                builtin_print(*args, **kwargs)

    __builtin__.print = print


def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # buffer size is in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        all_reduce_buffer()


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    # lazily allocate (and cache) the communication buffers
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size(0):
        all_gather_list._in_buffer = torch.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    if len(enc) >= max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
    if len(enc) > 255:
        # the length header below occupies a single byte, so the pickled
        # payload must fit in 255 bytes; fail loudly instead of overflowing
        raise ValueError('encoded data exceeds single-byte length header: {}'.format(len(enc)))
    in_buffer[0] = len(enc)
    in_buffer[1:len(enc)+1] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = out_buffer[0]
        result.append(
            pickle.loads(bytes(out_buffer[1:size+1].tolist()))
        )
    return result
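

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a minimal outline
# of how these helpers are typically wired into a data-parallel training loop.
# `args`, `model`, `optimizer`, and `batch_iterator` are assumed placeholders
# defined by the calling script, and rescaling by the world size (to average
# gradients) is one possible choice of denominator, not mandated here.
#
#     distributed_init(args)
#     for batch in batch_iterator:
#         loss = model(batch)
#         loss.backward()
#         # flatten the gradients into a list and all-reduce them in chunks
#         grads = [p.grad.data for p in model.parameters() if p.grad is not None]
#         all_reduce_and_rescale_tensors(grads, rescale_denom=args.distributed_world_size)
#         optimizer.step()
#         # gather small, picklable per-worker stats for logging
#         stats = all_gather_list({'loss': loss.item()})
#
# On non-master ranks, suppress_output() replaces the built-in print, so
# printing there requires passing `force=True` to print.
# ---------------------------------------------------------------------------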