Unverified Commit b57825f3 authored by Jiazhen Wang, committed by GitHub

[Fix] fix train example (#1502)

* [Fix] fix train example

* [Fix] fix a detail in train example and add warning in MMDP

* [Fix] fix docstring

* [Fix] fix docstring
parent e85c43ab
@@ -42,7 +42,9 @@ class Model(nn.Module):
 if __name__ == '__main__':
     model = Model()
     if torch.cuda.is_available():
-        model = MMDataParallel(model.cuda())
+        # only use gpu:0 to train
+        # Solved issue https://github.com/open-mmlab/mmcv/issues/1470
+        model = MMDataParallel(model.cuda(), device_ids=[0])
     # dataset and dataloader
     transform = transforms.Compose([
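The hunk above pins the wrapper to a single device. Below is a minimal, self-contained sketch of the fixed pattern; the toy model is illustrative rather than the repo's example network, and only torch and mmcv are assumed to be installed.

import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel


class Model(nn.Module):
    """Toy stand-in for the example's network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


if __name__ == '__main__':
    model = Model()
    if torch.cuda.is_available():
        # Pin the wrapper to gpu:0; without device_ids, DataParallel would
        # try to scatter inputs across every visible GPU (issue 1470).
        model = MMDataParallel(model.cuda(), device_ids=[0])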
@@ -15,6 +15,14 @@ class MMDataParallel(DataParallel):
       flexible control of input data during both GPU and CPU inference.
     - It implement two more APIs ``train_step()`` and ``val_step()``.
 
+    .. warning::
+        MMDataParallel only supports single GPU training, if you need to
+        train with multiple GPUs, please use MMDistributedDataParallel
+        instead. If you have multiple GPUs and you just want to use
+        MMDataParallel, you can set the environment variable
+        ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
+        ``device_ids=[0]``.
+
     Args:
         module (:class:`nn.Module`): Module to be encapsulated.
         device_ids (list[int]): Device IDS of modules to be scattered to.
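The new warning also mentions an environment-variable workaround. A rough sketch of that option, assuming nothing has initialized CUDA before the variable is set:

import os

# Hide all GPUs except gpu:0 before torch initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel

model = nn.Linear(4, 2)  # stand-in module for illustration
if torch.cuda.is_available():
    # Only one device is visible, so the wrapper defaults to a single GPU.
    model = MMDataParallel(model.cuda())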
@@ -54,7 +62,7 @@ class MMDataParallel(DataParallel):
         assert len(self.device_ids) == 1, \
             ('MMDataParallel only supports single GPU training, if you need to'
              ' train with multiple GPUs, please use MMDistributedDataParallel'
-             'instead.')
+             ' instead.')
         for t in chain(self.module.parameters(), self.module.buffers()):
             if t.device != self.src_device_obj:
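For completeness, a rough sketch of the multi-GPU path that the warning and the assertion point to. This is not part of the commit; it assumes one process per GPU launched by a distributed launcher (e.g. torchrun) that sets LOCAL_RANK, and uses a stand-in module rather than the example's model.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel

if __name__ == '__main__':
    # One process per GPU; the launcher provides LOCAL_RANK for each process.
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(4, 2).cuda()  # stand-in for the example's model
    model = MMDistributedDataParallel(
        model, device_ids=[local_rank], broadcast_buffers=False)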