"src/vscode:/vscode.git/clone" did not exist on "f67639b0bb54d3ccf7fc17157ba0b1e2e959ac5e"
Commit c211ab13 authored by Kai Chen's avatar Kai Chen
Browse files

use attributes to check if a model is a DataParallel or DistributedDataParallel

parent 2d231c2b
......@@ -4,7 +4,6 @@ from collections import OrderedDict
import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils import model_zoo
......@@ -102,7 +101,7 @@ def load_checkpoint(model,
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
# load state_dict
if isinstance(model, (DataParallel, DistributedDataParallel)):
if hasattr(model, 'module'):
load_state_dict(model.module, state_dict, strict, logger)
else:
load_state_dict(model, state_dict, strict, logger)
......@@ -144,7 +143,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if isinstance(model, (DataParallel, DistributedDataParallel)):
if hasattr(model, 'module'):
model = model.module
checkpoint = {
......
......@@ -4,7 +4,6 @@ import time
import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from .log_buffer import LogBuffer
from .. import hooks
......@@ -42,7 +41,7 @@ class Runner(object):
raise TypeError('"work_dir" must be a str or None')
# get model name from the model class
if isinstance(self.model, (DataParallel, DistributedDataParallel)):
if hasattr(self.model, 'module'):
self._model_name = self.model.module.__class__.__name__
else:
self._model_name = self.model.__class__.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment