Commit b7968de7 authored by Kai Chen

modify MMDistributedDataParallel: no longer inherits from DistributedDataParallel

parent e74c260f
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                          _take_tensors)

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(nn.Module):
    """Distributed model wrapper built on nn.Module instead of
    DistributedDataParallel: it broadcasts parameters/buffers from rank 0 on
    the first forward pass and scatters inputs to the current GPU, but does
    not itself all-reduce gradients."""

    def __init__(self, module, dim=0, broadcast_buffers=True):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        self.first_synced = False
        self.broadcast_bucket_size = 32 * 1024 * 1024  # 32 MB buckets

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        # Broadcast tensors from rank 0 in flattened buckets to cut down the
        # number of collective calls.
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def sync_params(self):
        # Sync all parameters (and optionally buffers) from rank 0 so that
        # every process starts from identical weights.
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            buffers = [b.data for b in self.module._all_buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        # Parameters are synced lazily on the first forward pass.
        if not self.first_synced:
            self.sync_params()
            self.first_synced = True
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])
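
For context, a minimal usage sketch of the new wrapper; the import path, wrap_model, and the local_rank handling below are illustrative assumptions, not part of this commit:

# Usage sketch, not part of this commit. Assumes one process per GPU and that
# `local_rank` is supplied by the launcher (e.g. via --local_rank).
import torch
import torch.distributed as dist

from mmdet.core.parallel import MMDistributedDataParallel  # illustrative path


def wrap_model(model, local_rank):
    torch.cuda.set_device(local_rank)        # bind this process to its GPU
    dist.init_process_group(backend='nccl')  # env:// rendezvous via MASTER_ADDR,
                                             # MASTER_PORT, RANK, WORLD_SIZE
    return MMDistributedDataParallel(model.cuda())  # params sync on 1st forward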
@@ -2,8 +2,4 @@
 
 PYTHON=${PYTHON:-"python"}
 
-$PYTHON train.py $1 --dist --world-size $2 --rank 0 &
-let MAX_RANK=$2-1
-for i in `seq 1 $MAX_RANK`; do
-    $PYTHON train.py $1 --dist --world-size $2 --rank $i > /dev/null 2>&1 &
-done
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch
\ No newline at end of file
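
The launcher replaces the manual per-rank loop: torch.distributed.launch spawns --nproc_per_node processes and exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT (and passes --local_rank). The --launcher pytorch path in train.py is then expected to initialize the process group from that environment, roughly as sketched below; the actual init helper is not shown in this diff, so the name init_dist_pytorch is an assumption:

# Sketch of the `--launcher pytorch` initialization; the real helper is not
# shown in this diff, so the function name is an assumption.
import os

import torch
import torch.distributed as dist


def init_dist_pytorch(backend='nccl'):
    rank = int(os.environ['RANK'])            # set by torch.distributed.launch
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)    # one GPU per local process
    dist.init_process_group(backend=backend)  # env:// rendezvous via MASTER_*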
@@ -95,10 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(
-            model,
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False).cuda()
+        model = MMDistributedDataParallel(model).cuda()
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
 
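
Because MMDistributedDataParallel no longer inherits from DistributedDataParallel, gradients are not averaged automatically during backward; the training loop has to all-reduce them before each optimizer step. A hypothetical helper, not part of this diff, following the same coalesced-bucket pattern as sync_params:

# Hypothetical gradient all-reduce helper, not part of this diff; it mirrors
# the coalesced broadcast used in _dist_broadcast_coalesced above.
import torch.distributed as dist
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                          _take_tensors)


def allreduce_grads(model, bucket_size=32 * 1024 * 1024):
    world_size = dist.get_world_size()
    grads = [p.grad.data for p in model.parameters()
             if p.requires_grad and p.grad is not None]
    for bucket in _take_tensors(grads, bucket_size):
        flat = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat)
        flat.div_(world_size)  # average the summed gradients
        for grad, synced in zip(bucket,
                                _unflatten_dense_tensors(flat, bucket)):
            grad.copy_(synced)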