Unverified Commit 5a99f587 authored by WRH's avatar WRH Committed by GitHub
Browse files

[Feature]: support CPU training with MMDataParallel (#972)

* support for CPU training

* Update .pre-commit-config.yaml

* Update data_parallel.py
parent 841a078e
......@@ -49,7 +49,7 @@ class MMDataParallel(DataParallel):
# We add the following line thus the module could gather and
# convert data containers as those in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.train_step(*inputs, **kwargs)
return self.module.train_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
......@@ -71,7 +71,7 @@ class MMDataParallel(DataParallel):
# We add the following line thus the module could gather and
# convert data containers as those in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.val_step(*inputs, **kwargs)
return self.module.val_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment