Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
c211ab13
Commit
c211ab13
authored
Sep 25, 2018
by
Kai Chen
Browse files
use attributes to check if a model is a DataParallel or DistributedDataParallel
parent
2d231c2b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
5 deletions
+3
-5
mmcv/torchpack/io.py
mmcv/torchpack/io.py
+2
-3
mmcv/torchpack/runner/runner.py
mmcv/torchpack/runner/runner.py
+1
-2
No files found.
mmcv/torchpack/io.py
View file @
c211ab13
...
@@ -4,7 +4,6 @@ from collections import OrderedDict
...
@@ -4,7 +4,6 @@ from collections import OrderedDict
import
mmcv
import
mmcv
import
torch
import
torch
from
torch.nn.parallel
import
DataParallel
,
DistributedDataParallel
from
torch.utils
import
model_zoo
from
torch.utils
import
model_zoo
...
@@ -102,7 +101,7 @@ def load_checkpoint(model,
...
@@ -102,7 +101,7 @@ def load_checkpoint(model,
if
list
(
state_dict
.
keys
())[
0
].
startswith
(
'module.'
):
if
list
(
state_dict
.
keys
())[
0
].
startswith
(
'module.'
):
state_dict
=
{
k
[
7
:]:
v
for
k
,
v
in
checkpoint
[
'state_dict'
].
items
()}
state_dict
=
{
k
[
7
:]:
v
for
k
,
v
in
checkpoint
[
'state_dict'
].
items
()}
# load state_dict
# load state_dict
if
isinstance
(
model
,
(
DataParallel
,
DistributedDataParallel
)
):
if
hasattr
(
model
,
'module'
):
load_state_dict
(
model
.
module
,
state_dict
,
strict
,
logger
)
load_state_dict
(
model
.
module
,
state_dict
,
strict
,
logger
)
else
:
else
:
load_state_dict
(
model
,
state_dict
,
strict
,
logger
)
load_state_dict
(
model
,
state_dict
,
strict
,
logger
)
...
@@ -144,7 +143,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
...
@@ -144,7 +143,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
meta
.
update
(
mmcv_version
=
mmcv
.
__version__
,
time
=
time
.
asctime
())
meta
.
update
(
mmcv_version
=
mmcv
.
__version__
,
time
=
time
.
asctime
())
mmcv
.
mkdir_or_exist
(
osp
.
dirname
(
filename
))
mmcv
.
mkdir_or_exist
(
osp
.
dirname
(
filename
))
if
isinstance
(
model
,
(
DataParallel
,
DistributedDataParallel
)
):
if
hasattr
(
model
,
'module'
):
model
=
model
.
module
model
=
model
.
module
checkpoint
=
{
checkpoint
=
{
...
...
mmcv/torchpack/runner/runner.py
View file @
c211ab13
...
@@ -4,7 +4,6 @@ import time
...
@@ -4,7 +4,6 @@ import time
import
mmcv
import
mmcv
import
torch
import
torch
from
torch.nn.parallel
import
DataParallel
,
DistributedDataParallel
from
.log_buffer
import
LogBuffer
from
.log_buffer
import
LogBuffer
from
..
import
hooks
from
..
import
hooks
...
@@ -42,7 +41,7 @@ class Runner(object):
...
@@ -42,7 +41,7 @@ class Runner(object):
raise
TypeError
(
'"work_dir" must be a str or None'
)
raise
TypeError
(
'"work_dir" must be a str or None'
)
# get model name from the model class
# get model name from the model class
if
isinstance
(
self
.
model
,
(
DataParallel
,
DistributedDataParallel
)
):
if
hasattr
(
self
.
model
,
'module'
):
self
.
_model_name
=
self
.
model
.
module
.
__class__
.
__name__
self
.
_model_name
=
self
.
model
.
module
.
__class__
.
__name__
else
:
else
:
self
.
_model_name
=
self
.
model
.
__class__
.
__name__
self
.
_model_name
=
self
.
model
.
__class__
.
__name__
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment