##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import threading
import torch
import torch.cuda.nccl as nccl
import torch.cuda.comm as comm
from torch.autograd import Variable, Function
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import scatter, scatter_kwargs, \
    gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


class ModelDataParallel(Module):
    """Implements data parallelism at the module level.

    .. _ModelDataParallel:

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking in the batch
    dimension. In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    CriterionDataParallel_.

    The batch size should be larger than the number of GPUs used, and an
    integer multiple of it, so that each chunk is the same size and each
    GPU processes the same number of samples.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Example::

        >>> net = ModelDataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)
    """

    def __init__(self, module, device_ids=None, dim=0):
        super(ModelDataParallel, self).__init__()
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.master_mean, self.master_var = {}, {}
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        # split the inputs across devices along the batch dimension
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module,
                                  self.device_ids[:len(inputs)])
        # return the list of per-GPU outputs without gathering them
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs)
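
# A minimal usage sketch (not part of the original API): `_demo_model_forward`
# is a hypothetical helper added here for illustration only. It assumes the
# module's parameters already live on device_ids[0], so that replicate() can
# broadcast them to the remaining devices.
def _demo_model_forward(model, input_var, device_ids=(0, 1)):
    net = ModelDataParallel(model, device_ids=list(device_ids))
    # With more than one device, the per-GPU outputs come back as a list and
    # are intentionally left un-gathered; feed them to the companion
    # CriterionDataParallel defined below.
    return net(input_var)
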
class CriterionDataParallel(Module):
    """
    .. _CriterionDataParallel:

    Computes the loss on multiple GPUs, which balances memory usage for
    semantic segmentation. The targets are split across the specified
    devices by chunking in the batch dimension. Please use together with
    ModelDataParallel_.
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(CriterionDataParallel, self).__init__()
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, inputs, *targets, **kwargs):
        # the inputs are already scattered (the per-GPU outputs of
        # ModelDataParallel); scatter the targets to match
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, targets, kwargs):
        return criterion_parallel_apply(replicas, inputs, targets, kwargs)

    def gather(self, outputs, output_device):
        # average the per-GPU losses into a single scalar
        return gather(outputs, output_device, dim=self.dim).mean()
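

# A runnable smoke test, sketched under assumptions that go beyond this
# section: at least two visible GPUs, a `criterion_parallel_apply` helper
# defined elsewhere in this file (it is referenced above but not shown
# here), and a PyTorch version whose nn.CrossEntropyLoss accepts
# (N, C, H, W) scores against (N, H, W) integer targets, matching the
# semantic-segmentation use case.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Conv2d(3, 4, kernel_size=3, padding=1).cuda()
    net = ModelDataParallel(model, device_ids=[0, 1])
    criterion = CriterionDataParallel(nn.CrossEntropyLoss(),
                                      device_ids=[0, 1])

    x = Variable(torch.randn(8, 3, 32, 32).cuda())
    y = Variable(torch.LongTensor(8, 32, 32).random_(0, 4).cuda())

    outputs = net(x)              # list of per-GPU score maps, not gathered
    loss = criterion(outputs, y)  # per-GPU losses gathered and averaged
    loss.backward()
    print('loss:', loss.data[0])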