import torch
from torch.autograd import Variable
from torch.nn.parallel._functions import Scatter, Gather


def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
    r"""
    Slices variables into approximately equal chunks and distributes them
    across the given GPUs. Duplicates references to objects that are not
    variables. Does not support Tensors.
    """
    def scatter_map(obj):
        # Variables are split along `dim` and sent to the target GPUs.
        if isinstance(obj, Variable):
            return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        # Containers are scattered element-wise and then regrouped, so the
        # result is one container per target GPU.
        if isinstance(obj, tuple):
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list):
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict):
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Anything else is replicated by reference, once per target GPU.
        return [obj for targets in target_gpus]

    return scatter_map(inputs)


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None):
    r"""Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else []
    # Pad the shorter of the two so every GPU receives both an args tuple
    # and a kwargs dict, even if one of them was empty.
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
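

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module above): shows how
# `scatter_kwargs` splits a positional Variable and a keyword Variable across
# two GPUs along the batch dimension. It assumes a CUDA build with at least
# two devices; the device ids, shapes, and the 'mask' keyword are made up
# purely for this demo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
        x = Variable(torch.randn(8, 3))            # batch of 8 examples
        kw = {'mask': Variable(torch.ones(8, 1))}  # per-example keyword arg
        inputs, kwargs = scatter_kwargs((x,), kw, target_gpus=[0, 1], dim=0)
        # Each GPU receives one chunk of the batch plus a matching kwargs dict.
        for gpu_inputs, gpu_kwargs in zip(inputs, kwargs):
            print(gpu_inputs[0].size(), gpu_inputs[0].get_device(),
                  gpu_kwargs['mask'].size())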