OpenDAS / MMCV · Commit 6bb244f2 (unverified)

add train_step() and val_step() for MMDP (#354)

Authored by Kai Chen on Jun 18, 2020; committed via GitHub on Jun 18, 2020.
Parent: c74d729d
Showing 2 changed files with 43 additions and 0 deletions:

    mmcv/parallel/data_parallel.py      +40  -0
    tests/test_runner/test_runner.py     +3  -0
mmcv/parallel/data_parallel.py
 # Copyright (c) Open-MMLab. All rights reserved.
+from itertools import chain
+
 from torch.nn.parallel import DataParallel

 from .scatter_gather import scatter_kwargs
...
@@ -8,3 +10,41 @@ class MMDataParallel(DataParallel):

     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            return self.module.train_step(*inputs, **kwargs)
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training; if you need'
+             ' to train with multiple GPUs, please use'
+             ' MMDistributedDataParallel instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.train_step(*inputs[0], **kwargs[0])
+
+    def val_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            return self.module.val_step(*inputs, **kwargs)
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training; if you need'
+             ' to train with multiple GPUs, please use'
+             ' MMDistributedDataParallel instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.val_step(*inputs[0], **kwargs[0])
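What this adds: MMCV runners drive each iteration through model.train_step(data_batch, optimizer) (and model.val_step() during validation) rather than calling forward() directly. With these wrappers, that call now survives wrapping the model in MMDataParallel: with no GPUs configured it is forwarded to the inner module as-is, and with a single GPU the arguments are first scattered to device_ids[0]. Below is a minimal usage sketch, not part of the commit; the ToyModel class and the dummy batch are hypothetical stand-ins.

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class ToyModel(nn.Module):
    """Hypothetical module exposing the train_step()/val_step() interface."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer=None):
        # One full training iteration; the runner never calls forward() itself.
        loss = self.fc(data_batch['x']).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=4)

    def val_step(self, data_batch, optimizer=None):
        return self.train_step(data_batch, optimizer)


if torch.cuda.is_available():
    model = MMDataParallel(ToyModel().cuda(), device_ids=[0])
else:
    model = MMDataParallel(ToyModel())  # device_ids ends up empty on CPU

batch = dict(x=torch.randn(4, 2))
# CPU: the call falls through to ToyModel.train_step(batch) unchanged.
# Single GPU: the batch is scattered to device_ids[0] before delegation.
outputs = model.train_step(batch)
print(outputs['log_vars'])

Note the [0] indexing in the diff: scatter() always returns one shard per GPU, so with a single GPU the first (and only) shard is unpacked before delegating to the wrapped module.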
tests/test_runner/test_runner.py

...
@@ -113,6 +113,9 @@ def test_runner_with_parallel():

     model = MMDataParallel(OldStyleModel())
     _ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())

+    model = MMDataParallel(Model())
+    _ = EpochBasedRunner(model, logger=logging.getLogger())
+
     with pytest.raises(RuntimeError):
         # batch_processor and train_step() cannot both be set
...
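For the test change: because MMDataParallel itself now defines train_step(), a wrapped new-style model can be handed straight to EpochBasedRunner without a batch_processor, which is exactly the pair of lines added above; the pre-existing pytest.raises block keeps guarding the conflicting combination. A self-contained re-creation of both cases, with hypothetical stand-ins for the suite's Model and batch_processor:

import logging

import pytest
import torch.nn as nn

from mmcv.parallel import MMDataParallel
from mmcv.runner import EpochBasedRunner


class Model(nn.Module):
    """Stand-in for the suite's new-style model: it owns its iteration logic."""

    def train_step(self, data_batch, optimizer=None):
        return dict(loss=0.0)


def batch_processor(model, data, train_mode):
    """Old-style hook that owned the per-iteration logic instead."""
    return dict(loss=0.0)


model = MMDataParallel(Model())

# Accepted: the runner will drive training through model.train_step().
_ = EpochBasedRunner(model, logger=logging.getLogger())

with pytest.raises(RuntimeError):
    # Rejected: batch_processor and train_step() cannot both be set.
    _ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())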