Commit da7bb063 authored by Wenwei Zhang, committed by Kai Chen

Load state dict (#164)

* use load_from_state_dict to load ckpt

* reformat

* reformat

* reformat import

* pass flake8

* isort skip
parent b34eef6e
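The change is internal to mmcv.runner.checkpoint: load_state_dict now delegates to PyTorch's Module._load_from_state_dict instead of copying tensors by hand, while the public loading helpers keep their signatures. A rough usage sketch of those helpers (the torchvision model and the checkpoint paths below are placeholders, not part of this commit):

    import torchvision

    from mmcv.runner import load_checkpoint, load_state_dict, save_checkpoint

    model = torchvision.models.resnet50()

    # load weights from a local file; strict=False reports missing/unexpected
    # keys as a warning instead of raising
    checkpoint = load_checkpoint(model, 'work_dir/latest.pth', strict=False)

    # load_state_dict can also be fed an in-memory state dict directly
    load_state_dict(model, model.state_dict(), strict=True)

    save_checkpoint(model, 'work_dir/resnet50_copy.pth')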
 # flake8: noqa
 from .arraymisc import *
-from .utils import *
 from .fileio import *
-from .opencv_info import *
 from .image import *
+from .opencv_info import *
+from .utils import *
+from .version import __version__
 from .video import *
 from .visualization import *
-from .version import __version__

 # The following modules are not imported to this level, so mmcv may be used
 # without PyTorch.
 # - runner
......
-from .quantization import quantize, dequantize
+from .quantization import dequantize, quantize

 __all__ = ['quantize', 'dequantize']
 from .alexnet import AlexNet
-from .vgg import VGG, make_vgg_layer
 from .resnet import ResNet, make_res_layer
-from .weight_init import (constant_init, xavier_init, normal_init,
-                          uniform_init, kaiming_init, caffe2_xavier_init)
+from .vgg import VGG, make_vgg_layer
+from .weight_init import (caffe2_xavier_init, constant_init, kaiming_init,
+                          normal_init, uniform_init, xavier_init)

 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
......
-from .io import load, dump, register_handler
 from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
-from .parse import list_from_file, dict_from_file
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file

 __all__ = [
     'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler',
......
-from .io import imread, imwrite, imfrombytes
-from .transforms import (solarize, posterize, bgr2gray, rgb2gray, gray2bgr,
-                         gray2rgb, bgr2rgb, rgb2bgr, bgr2hsv, hsv2bgr, bgr2hls,
-                         hls2bgr, iminvert, imflip, imrotate, imcrop, impad,
-                         impad_to_multiple, imnormalize, imdenormalize,
-                         imresize, imresize_like, imrescale)
+from .io import imfrombytes, imread, imwrite
+from .transforms import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
+                         gray2rgb, hls2bgr, hsv2bgr, imcrop, imdenormalize,
+                         imflip, iminvert, imnormalize, impad,
+                         impad_to_multiple, imrescale, imresize, imresize_like,
+                         imrotate, posterize, rgb2bgr, rgb2gray, solarize)

 __all__ = [
     'solarize', 'posterize', 'imread', 'imwrite', 'imfrombytes', 'bgr2gray',
......
-from .colorspace import (solarize, posterize, bgr2gray, rgb2gray, gray2bgr,
-                         gray2rgb, bgr2rgb, rgb2bgr, bgr2hsv, hsv2bgr, bgr2hls,
-                         hls2bgr, iminvert)
-from .geometry import imflip, imrotate, imcrop, impad, impad_to_multiple
-from .normalize import imnormalize, imdenormalize
-from .resize import imresize, imresize_like, imrescale
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
+                         gray2rgb, hls2bgr, hsv2bgr, iminvert, posterize,
+                         rgb2bgr, rgb2gray, solarize)
+from .geometry import imcrop, imflip, impad, impad_to_multiple, imrotate
+from .normalize import imdenormalize, imnormalize
+from .resize import imrescale, imresize, imresize_like

 __all__ = [
     'solarize', 'posterize', 'bgr2gray', 'rgb2gray', 'gray2bgr', 'gray2rgb',
......
-from .runner import Runner
+from .checkpoint import (load_checkpoint, load_state_dict, save_checkpoint,
+                         weights_to_cpu)
+from .dist_utils import get_dist_info, init_dist, master_only
+from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook,
+                    IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook,
+                    PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook,
+                    WandbLoggerHook)
 from .log_buffer import LogBuffer
-from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
-                    OptimizerHook, IterTimerHook, DistSamplerSeedHook,
-                    LoggerHook, TextLoggerHook, PaviLoggerHook,
-                    TensorboardLoggerHook, WandbLoggerHook)
-from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
-                         save_checkpoint)
 from .parallel_test import parallel_test
 from .priority import Priority, get_priority
+from .runner import Runner
 from .utils import get_host_info, get_time_str, obj_from_dict
-from .dist_utils import init_dist, get_dist_info, master_only

 __all__ = [
     'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
......
@@ -8,7 +8,6 @@ from importlib import import_module
 import torch
 import torchvision
-from terminaltables import AsciiTable
 from torch.utils import model_zoo

 import mmcv
@@ -56,42 +55,39 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
             message. If not specified, print function will be used.
     """
     unexpected_keys = []
-    shape_mismatch_pairs = []
+    all_missing_keys = []
+    err_msg = []

-    own_state = module.state_dict()
-    for name, param in state_dict.items():
-        if name not in own_state:
-            unexpected_keys.append(name)
-            continue
-        if isinstance(param, torch.nn.Parameter):
-            # backwards compatibility for serialized parameters
-            param = param.data
-        if param.size() != own_state[name].size():
-            shape_mismatch_pairs.append(
-                [name, own_state[name].size(),
-                 param.size()])
-            continue
-        own_state[name].copy_(param)
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata

-    all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
+    # use _load_from_state_dict to enable checkpoint version control
+    def load(module, prefix=''):
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+
     # ignore "num_batches_tracked" of BN layers
     missing_keys = [
         key for key in all_missing_keys if 'num_batches_tracked' not in key
     ]

-    err_msg = []
     if unexpected_keys:
         err_msg.append('unexpected key in source state_dict: {}\n'.format(
             ', '.join(unexpected_keys)))
     if missing_keys:
         err_msg.append('missing keys in source state_dict: {}\n'.format(
             ', '.join(missing_keys)))
-    if shape_mismatch_pairs:
-        mismatch_info = 'these keys have mismatched shape:\n'
-        header = ['key', 'expected shape', 'loaded shape']
-        table_data = [header] + shape_mismatch_pairs
-        table = AsciiTable(table_data)
-        err_msg.append(mismatch_info + table.table)

     rank, _ = get_dist_info()
     if len(err_msg) > 0 and rank == 0:
......
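Because load_state_dict now routes through Module._load_from_state_dict together with the checkpoint's per-module _metadata, a module can reinterpret keys written by an older version of itself by overriding that hook. A minimal sketch of the pattern (not part of this commit; ConvBlock and the conv/body parameter names are invented for illustration):

    import torch.nn as nn


    class ConvBlock(nn.Module):
        # bumping _version makes state_dict() record it in the checkpoint's
        # _metadata; load_state_dict hands it back as local_metadata['version']
        _version = 2

        def __init__(self):
            super(ConvBlock, self).__init__()
            self.body = nn.Conv2d(3, 8, 3)  # renamed from "conv" in version 2

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs):
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                # remap keys from the old layout before the default loader runs
                for param in ('weight', 'bias'):
                    old_key = prefix + 'conv.' + param
                    new_key = prefix + 'body.' + param
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)
            super(ConvBlock, self)._load_from_state_dict(
                state_dict, prefix, local_metadata, strict, missing_keys,
                unexpected_keys, error_msgs)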
-from .hook import Hook
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .hook import Hook
+from .iter_timer import IterTimerHook
+from .logger import (LoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+                     TextLoggerHook, WandbLoggerHook)
 from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
 from .optimizer import OptimizerHook
-from .iter_timer import IterTimerHook
 from .sampler_seed import DistSamplerSeedHook
-from .memory import EmptyCacheHook
-from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
-                     TensorboardLoggerHook, WandbLoggerHook)

 __all__ = [
     'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
......
+import numbers
 from ...dist_utils import master_only
 from .base import LoggerHook
-import numbers

 class WandbLoggerHook(LoggerHook):
......
-from .config import ConfigDict, Config
-from .misc import (is_str, iter_cast, list_cast, tuple_cast, is_seq_of,
-                   is_list_of, is_tuple_of, slice_list, concat_list,
-                   check_prerequisites, requires_package, requires_executable)
-from .path import (is_filepath, fopen, check_file_exist, mkdir_or_exist,
-                   symlink, scandir, FileNotFoundError)
-from .progressbar import (ProgressBar, track_progress, track_parallel_progress,
-                          track_iter_progress)
+from .config import Config, ConfigDict
+from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
+                   is_str, is_tuple_of, iter_cast, list_cast,
+                   requires_executable, requires_package, slice_list,
+                   tuple_cast)
+from .path import (FileNotFoundError, check_file_exist, fopen, is_filepath,
+                   mkdir_or_exist, scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+                          track_parallel_progress, track_progress)
 from .timer import Timer, TimerError, check_time

 __all__ = [
......
 from .io import Cache, VideoReader, frames2video
-from .processing import convert_video, resize_video, cut_video, concat_video
-from .optflow import (flowread, flowwrite, quantize_flow, dequantize_flow,
-                      flow_warp)
+from .optflow import (dequantize_flow, flow_warp, flowread, flowwrite,
+                      quantize_flow)
+from .processing import concat_video, convert_video, cut_video, resize_video

 __all__ = [
     'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
......
 from .color import Color, color_val
 from .image import imshow, imshow_bboxes, imshow_det_bboxes
-from .optflow import flowshow, flow2rgb, make_color_wheel
+from .optflow import flow2rgb, flowshow, make_color_wheel

 __all__ = [
     'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
......
 import sys
 import time
 try:
     from StringIO import StringIO
 except ImportError:
     from io import StringIO
-import pytest
-import mmcv
+import pytest  # isort:skip
+import mmcv  # isort:skip

 def reset_string_io(io):
......
 import os.path as osp
 import tempfile
 import warnings
 from mock import MagicMock
......