Commit d400ac08 authored by Kai Chen

add parallel module

parent 47fc7a69

mmcv/__init__.py
@@ -7,4 +7,7 @@ from .image import *
 from .video import *
 from .visualization import *
 from .version import __version__
-# runner is not imported here, so mmcv may be used without PyTorch
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch:
+# - runner
+# - parallel
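
The point of that comment, as a hypothetical session (not part of the commit):

import mmcv                                # works even without PyTorch installed
from mmcv.parallel import MMDataParallel   # this import does require torch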

mmcv/parallel/__init__.py
from .data_container import DataContainer
from .data_parallel import MMDataParallel
from .distributed import MMDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs
__all__ = [
'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel', 'scatter',
'scatter_kwargs'
]

mmcv/parallel/_functions.py
import torch
from torch.nn.parallel._functions import _get_stream
def scatter(input, devices, streams=None):
    """Scatter a tensor (or a list of tensors) across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)
    if isinstance(input, list):
        # Split the list across devices with a ceiling division, e.g. 5 items
        # on 2 GPUs gives chunk_size = 3: items 0-2 go to devices[0] and
        # items 3-4 to devices[1].
        chunk_size = (len(input) - 1) // len(devices) + 1
        outputs = [
            scatter(input[i], [devices[i // chunk_size]],
                    [streams[i // chunk_size]]) for i in range(len(input))
        ]
        return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
# TODO: copy to a pinned buffer first (if copying from CPU)
stream = streams[0] if output.numel() > 0 else None
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
return output
else:
        raise TypeError('Unknown type {}.'.format(type(input)))
def synchronize_stream(output, devices, streams):
    # Make each device's default stream wait for its background copy stream,
    # so subsequent kernels on the default stream see fully copied data.
if isinstance(output, list):
chunk_size = len(output) // len(devices)
for i in range(len(devices)):
for j in range(chunk_size):
synchronize_stream(output[i * chunk_size + j], [devices[i]],
[streams[i]])
elif isinstance(output, torch.Tensor):
if output.numel() != 0:
with torch.cuda.device(devices[0]):
main_stream = torch.cuda.current_stream()
main_stream.wait_stream(streams[0])
output.record_stream(main_stream)
else:
        raise TypeError('Unknown type {}.'.format(type(output)))
def get_input_device(input):
    # Return the index of the GPU holding the input (recursing into lists),
    # or -1 if the data lives on the CPU.
if isinstance(input, list):
for item in input:
input_device = get_input_device(item)
if input_device != -1:
return input_device
return -1
elif isinstance(input, torch.Tensor):
return input.get_device() if input.is_cuda else -1
else:
        raise TypeError('Unknown type {}.'.format(type(input)))
class Scatter(object):
    # A simplified analogue of torch.nn.parallel._functions.Scatter that only
    # moves data to the target GPUs; it is not an autograd Function.
@staticmethod
def forward(target_gpus, input):
input_device = get_input_device(input)
streams = None
if input_device == -1:
# Perform CPU to GPU copies in a background stream
streams = [_get_stream(device) for device in target_gpus]
outputs = scatter(input, target_gpus, streams)
# Synchronize with the copy stream
if streams is not None:
synchronize_stream(outputs, target_gpus, streams)
return tuple(outputs)
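
A minimal sketch of the helpers above in action (illustrative, not part of the commit; it assumes a machine with at least two CUDA devices). `Scatter.forward` copies a list of CPU tensors to the target GPUs on background streams, then waits on those streams before returning:

if torch.cuda.device_count() >= 2:
    inputs = [torch.ones(2, 2) for _ in range(4)]  # CPU tensors
    outputs = Scatter.forward([0, 1], inputs)
    # chunk_size = (4 - 1) // 2 + 1 = 2, so outputs[0:2] land on GPU 0
    # and outputs[2:4] on GPU 1
    assert outputs[0].get_device() == 0
    assert outputs[3].get_device() == 1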

mmcv/parallel/data_container.py
import functools
import torch
def assert_tensor_type(func):
    # Decorator for DataContainer methods that only make sense when the
    # wrapped payload is a torch.Tensor.
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not isinstance(args[0].data, torch.Tensor):
raise AttributeError('{} has no attribute {} for type {}'.format(
args[0].__class__.__name__, func.__name__, args[0].datatype))
return func(*args, **kwargs)
return wrapper
class DataContainer(object):
    """A container for any type of object.

    Typically, tensors will be stacked in the collate function and sliced
    along some dimension in the scatter function. This behavior has some
    limitations:

    1. All tensors have to be the same size.
    2. Types are limited (numpy array or Tensor).

    We design `DataContainer` and `MMDataParallel` to overcome these
    limitations. The behavior can be either of the following:

    - copy to GPU, pad all tensors to the same size and stack them
    - copy to GPU without stacking
    - leave the objects as-is and pass them to the model
    """
def __init__(self, data, stack=False, padding_value=0, cpu_only=False):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
self._padding_value = padding_value
def __repr__(self):
return '{}({})'.format(self.__class__.__name__, repr(self.data))
@property
def data(self):
return self._data
@property
def datatype(self):
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self):
return self._cpu_only
@property
def stack(self):
return self._stack
@property
def padding_value(self):
return self._padding_value
@assert_tensor_type
def size(self, *args, **kwargs):
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self):
return self.data.dim()
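
A quick sketch of how `DataContainer` is meant to be used (illustrative, not part of the commit; it reuses the imports above): the container simply wraps a payload, and its flags tell the downstream collate/scatter logic how to treat it.

img = DataContainer(torch.zeros(3, 224, 224), stack=True, padding_value=0)
meta = DataContainer(dict(filename='a.jpg'), cpu_only=True)
print(img.stack, img.datatype)   # True torch.FloatTensor
print(meta.cpu_only, meta.data)  # True {'filename': 'a.jpg'}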

mmcv/parallel/data_parallel.py
from torch.nn.parallel import DataParallel
from .scatter_gather import scatter_kwargs
class MMDataParallel(DataParallel):
    """DataParallel variant whose scatter understands DataContainer."""

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
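
A hypothetical usage sketch (not part of the commit; assumes two visible GPUs): wrapping a module with `MMDataParallel` routes `forward` through the DataContainer-aware `scatter_kwargs` instead of the stock scatter.

import torch
import torch.nn as nn

model = MMDataParallel(nn.Linear(4, 2).cuda(), device_ids=[0, 1])
out = model(torch.randn(8, 4).cuda())  # the batch is split across both GPUs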

mmcv/parallel/distributed.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
_take_tensors)
from .scatter_gather import scatter_kwargs
class MMDistributedDataParallel(nn.Module):
def __init__(self, module, dim=0, broadcast_buffers=True,
bucket_cap_mb=25):
super(MMDistributedDataParallel, self).__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers
self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
self._sync_params()
    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        # Broadcast tensors from rank 0 in flattened buckets of at most
        # buffer_size bytes, then copy the synced values back in place.
        for bucket in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(bucket)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)
    def _sync_params(self):
        # Broadcast the module state from rank 0 so all processes start from
        # identical weights; optionally sync buffers as well.
module_states = list(self.module.state_dict().values())
if len(module_states) > 0:
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
buffers = [b.data for b in self.module._all_buffers()]
if len(buffers) > 0:
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
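
A hedged sketch of using this wrapper in a one-process-per-GPU launch (not part of the commit; `local_rank` and `MyModel` are hypothetical placeholders supplied by the launcher and user code). As written here, the class only syncs parameters/buffers at construction and scatters inputs in `forward`; it does not itself all-reduce gradients.

import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # assumes env:// variables are set
torch.cuda.set_device(local_rank)        # local_rank comes from the launcher
model = MMDistributedDataParallel(MyModel().cuda())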

mmcv/parallel/scatter_gather.py
import torch
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
def scatter(inputs, target_gpus, dim=0):
    """Scatter inputs to target gpus.

    The only difference from the original :func:`scatter` is that it adds
    support for :class:`~mmcv.parallel.DataContainer`.
    """
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return OrigScatter.apply(target_gpus, None, dim, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_gpus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
            # Objects of any other type are replicated once per GPU.
            return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary"""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
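
An illustrative check of the padding behavior (not part of the commit; scattering a real tensor requires CUDA): when `inputs` and `kwargs` yield a different number of per-GPU entries, the shorter side is padded with empty tuples/dicts so every device receives a matching pair.

inputs, kwargs = scatter_kwargs((torch.zeros(4, 3),), {}, [0, 1])
# The tensor is split along dim 0 into one chunk per GPU; kwargs was empty,
# so it is padded to one empty dict per GPU.
assert len(inputs) == 2 and kwargs == ({}, {})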

mmcv/runner/__init__.py
@@ -6,7 +6,7 @@ from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
                     TensorboardLoggerHook)
 from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
-from .parallel import parallel_test, worker_func
+from .parallel_test import parallel_test
 from .priority import Priority, get_priority
 from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
                     obj_from_dict)
@@ -16,7 +16,6 @@ __all__ = [
     'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
-    'parallel_test', 'worker_func', 'Priority', 'get_priority',
-    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
-    'obj_from_dict'
+    'parallel_test', 'Priority', 'get_priority', 'get_host_info',
+    'get_dist_info', 'master_only', 'get_time_str', 'obj_from_dict'
 ]