import torch
from ._functions import Scatter
from torch.nn.parallel._functions import Scatter as OrigScatter
from detkit.datasets.utils import DataContainer


def scatter(inputs, target_gpus, dim=0):
    """Distribute ``inputs`` across ``target_gpus``.

    Mirrors the recursive structure of PyTorch's own ``scatter`` helper, with
    one extra case: a ``DataContainer`` whose ``data`` attribute is a list is
    dispatched to the project-local ``Scatter.forward``.

    Args:
        inputs: Object to distribute. Tensors, ``DataContainer`` instances,
            and (possibly nested) tuples, lists and dicts of those get
            dedicated handling; any other object is repeated as-is.
        target_gpus: Iterable of target devices; it is forwarded unchanged
            to the underlying scatter functions.
        dim: Dimension forwarded to ``OrigScatter.apply`` for tensors.

    Returns:
        The scattered result. For objects without dedicated handling this is
        a list holding the same object once per entry of ``target_gpus``;
        containers are rebuilt per device from their scattered elements.
    """

    def _distribute(item):
        # Plain tensors are delegated to PyTorch's stock Scatter function.
        if isinstance(item, torch.Tensor):
            return OrigScatter.apply(target_gpus, None, dim, item)

        # DataContainers wrapping a list go through the project Scatter.
        if isinstance(item, DataContainer) and isinstance(item.data, list):
            return Scatter.forward(target_gpus, item.data)

        # Non-empty containers: scatter every element, then transpose so the
        # outer axis runs over devices and the inner one over elements.
        if isinstance(item, tuple) and len(item) > 0:
            per_element = [_distribute(element) for element in item]
            return list(zip(*per_element))

        if isinstance(item, list) and len(item) > 0:
            per_element = [_distribute(element) for element in item]
            return [list(group) for group in zip(*per_element)]

        if isinstance(item, dict) and len(item) > 0:
            rebuild = type(item)
            per_pair = [_distribute(pair) for pair in item.items()]
            return [rebuild(group) for group in zip(*per_pair)]

        # Everything else (including empty containers) is simply repeated
        # once per target device.
        return [item for _ in target_gpus]

    # The recursive helper refers to itself through a closure cell, and that
    # cell in turn keeps the helper alive: a reference cycle. Rebinding the
    # name to None once the work is done empties the cell and breaks the
    # cycle.
    try:
        return _distribute(inputs)
    finally:
        _distribute = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter positional and keyword arguments, padding to equal length.

    Args:
        inputs: Positional arguments to scatter; a falsy value yields an
            empty list instead of calling :func:`scatter`.
        kwargs: Keyword arguments to scatter; a falsy value yields an empty
            list instead of calling :func:`scatter`.
        target_gpus: Target devices, forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`.

    Returns:
        A pair ``(inputs, kwargs)`` of tuples of equal length. The shorter
        side is padded with empty tuples (for inputs) or empty dicts (for
        kwargs).
    """
    scattered_args = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive: positional side is short; negative: keyword side is short.
    shortfall = len(scattered_kwargs) - len(scattered_args)
    if shortfall > 0:
        scattered_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_args), tuple(scattered_kwargs)