Unverified commit 6bb244f2 authored by Kai Chen, committed by GitHub

add train_step() and val_step() for MMDP (#354)

parent c74d729d
# Copyright (c) Open-MMLab. All rights reserved.
from itertools import chain
from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs
@@ -8,3 +10,41 @@ class MMDataParallel(DataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
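
    # Note: scatter_kwargs mirrors torch's scatter semantics: tensors inside
    # `inputs`/`kwargs` are sliced along `self.dim` (the batch dimension by
    # default) and copied to the target GPUs, producing one (inputs, kwargs)
    # pair per device.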
    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.train_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training. If you need'
             ' to train with multiple GPUs, please use'
             ' MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])
    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.val_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training. If you need'
             ' to train with multiple GPUs, please use'
             ' MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])
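
For context, a minimal usage sketch of the new entry points, assuming a single visible GPU. ToyModel is hypothetical and not part of this change; only the MMDataParallel wrapper and the train_step()/val_step() contract come from the commit:

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class ToyModel(nn.Module):  # hypothetical model for illustration only

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x):
        # Return a dict to mirror how mmcv runners consume train_step() output.
        return dict(loss=self.forward(x).sum())

    def val_step(self, x):
        return dict(loss=self.forward(x).sum())


if torch.cuda.is_available():
    model = MMDataParallel(ToyModel().cuda(), device_ids=[0])
    # The batch is scattered to device_ids[0] before ToyModel.train_step()
    # runs, so the tensors it sees are already on the GPU.
    outputs = model.train_step(torch.rand(4, 2))
else:
    # Without GPUs, device_ids is empty and the wrapper delegates directly.
    model = MMDataParallel(ToyModel())
    outputs = model.train_step(torch.rand(4, 2))

On a CPU-only build, device_ids is empty and the wrapper simply delegates to the module's own train_step()/val_step(), which is what lets the updated test below construct MMDataParallel without a GPU.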
@@ -113,6 +113,9 @@ def test_runner_with_parallel():
    model = MMDataParallel(OldStyleModel())
    _ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())

    model = MMDataParallel(Model())
    _ = EpochBasedRunner(model, logger=logging.getLogger())
    with pytest.raises(RuntimeError):
        # batch_processor and train_step() cannot both be set
......
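
The test body is truncated here. A hedged sketch of the assertion it likely makes, assuming EpochBasedRunner rejects a batch_processor whenever the wrapped model already defines train_step() (Model and batch_processor are the fixtures defined earlier in the test file):

    # Hypothetical reconstruction, not the verbatim test: Model() defines
    # train_step(), so also passing batch_processor should raise.
    model = MMDataParallel(Model())
    with pytest.raises(RuntimeError):
        _ = EpochBasedRunner(
            model, batch_processor, logger=logging.getLogger())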