"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b2b3b1a8ab83b020ecaf32f45de3ef23644331cf"
Commit c211ab13 authored by Kai Chen's avatar Kai Chen
Browse files

use attributes to check if a model is a DataParallel or DistributedDataParallel

parent 2d231c2b
...@@ -4,7 +4,6 @@ from collections import OrderedDict ...@@ -4,7 +4,6 @@ from collections import OrderedDict
import mmcv import mmcv
import torch import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils import model_zoo from torch.utils import model_zoo
...@@ -102,7 +101,7 @@ def load_checkpoint(model, ...@@ -102,7 +101,7 @@ def load_checkpoint(model,
if list(state_dict.keys())[0].startswith('module.'): if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
# load state_dict # load state_dict
if isinstance(model, (DataParallel, DistributedDataParallel)): if hasattr(model, 'module'):
load_state_dict(model.module, state_dict, strict, logger) load_state_dict(model.module, state_dict, strict, logger)
else: else:
load_state_dict(model, state_dict, strict, logger) load_state_dict(model, state_dict, strict, logger)
...@@ -144,7 +143,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): ...@@ -144,7 +143,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename)) mmcv.mkdir_or_exist(osp.dirname(filename))
if isinstance(model, (DataParallel, DistributedDataParallel)): if hasattr(model, 'module'):
model = model.module model = model.module
checkpoint = { checkpoint = {
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import mmcv import mmcv
import torch import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .. import hooks from .. import hooks
...@@ -42,7 +41,7 @@ class Runner(object): ...@@ -42,7 +41,7 @@ class Runner(object):
raise TypeError('"work_dir" must be a str or None') raise TypeError('"work_dir" must be a str or None')
# get model name from the model class # get model name from the model class
if isinstance(self.model, (DataParallel, DistributedDataParallel)): if hasattr(self.model, 'module'):
self._model_name = self.model.module.__class__.__name__ self._model_name = self.model.module.__class__.__name__
else: else:
self._model_name = self.model.__class__.__name__ self._model_name = self.model.__class__.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment