Commit da7bb063 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by Kai Chen
Browse files

Load state dict (#164)

* use load_from_state_dict to load ckpt

* reformat

* reformat

* reformat import

* pass flake8

* isort skip
parent b34eef6e
# flake8: noqa
from .arraymisc import *
from .utils import *
from .fileio import *
from .opencv_info import *
from .image import *
from .opencv_info import *
from .utils import *
from .version import __version__
from .video import *
from .visualization import *
from .version import __version__
# The following modules are not imported to this level, so mmcv may be used
# without PyTorch.
# - runner
......
from .quantization import quantize, dequantize
from .quantization import dequantize, quantize
__all__ = ['quantize', 'dequantize']
from .alexnet import AlexNet
from .vgg import VGG, make_vgg_layer
from .resnet import ResNet, make_res_layer
from .weight_init import (constant_init, xavier_init, normal_init,
uniform_init, kaiming_init, caffe2_xavier_init)
from .vgg import VGG, make_vgg_layer
from .weight_init import (caffe2_xavier_init, constant_init, kaiming_init,
normal_init, uniform_init, xavier_init)
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
......
from .io import load, dump, register_handler
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .parse import list_from_file, dict_from_file
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file
__all__ = [
'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler',
......
from .io import imread, imwrite, imfrombytes
from .transforms import (solarize, posterize, bgr2gray, rgb2gray, gray2bgr,
gray2rgb, bgr2rgb, rgb2bgr, bgr2hsv, hsv2bgr, bgr2hls,
hls2bgr, iminvert, imflip, imrotate, imcrop, impad,
impad_to_multiple, imnormalize, imdenormalize,
imresize, imresize_like, imrescale)
from .io import imfrombytes, imread, imwrite
from .transforms import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
gray2rgb, hls2bgr, hsv2bgr, imcrop, imdenormalize,
imflip, iminvert, imnormalize, impad,
impad_to_multiple, imrescale, imresize, imresize_like,
imrotate, posterize, rgb2bgr, rgb2gray, solarize)
__all__ = [
'solarize', 'posterize', 'imread', 'imwrite', 'imfrombytes', 'bgr2gray',
......
from .colorspace import (solarize, posterize, bgr2gray, rgb2gray, gray2bgr,
gray2rgb, bgr2rgb, rgb2bgr, bgr2hsv, hsv2bgr, bgr2hls,
hls2bgr, iminvert)
from .geometry import imflip, imrotate, imcrop, impad, impad_to_multiple
from .normalize import imnormalize, imdenormalize
from .resize import imresize, imresize_like, imrescale
from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
gray2rgb, hls2bgr, hsv2bgr, iminvert, posterize,
rgb2bgr, rgb2gray, solarize)
from .geometry import imcrop, imflip, impad, impad_to_multiple, imrotate
from .normalize import imdenormalize, imnormalize
from .resize import imrescale, imresize, imresize_like
__all__ = [
'solarize', 'posterize', 'bgr2gray', 'rgb2gray', 'gray2bgr', 'gray2rgb',
......
from .runner import Runner
from .checkpoint import (load_checkpoint, load_state_dict, save_checkpoint,
weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook,
PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .log_buffer import LogBuffer
from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
OptimizerHook, IterTimerHook, DistSamplerSeedHook,
LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook, WandbLoggerHook)
from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint)
from .parallel_test import parallel_test
from .priority import Priority, get_priority
from .runner import Runner
from .utils import get_host_info, get_time_str, obj_from_dict
from .dist_utils import init_dist, get_dist_info, master_only
__all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
......
......@@ -8,7 +8,6 @@ from importlib import import_module
import torch
import torchvision
from terminaltables import AsciiTable
from torch.utils import model_zoo
import mmcv
......@@ -56,42 +55,39 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
message. If not specified, print function will be used.
"""
unexpected_keys = []
shape_mismatch_pairs = []
all_missing_keys = []
err_msg = []
own_state = module.state_dict()
for name, param in state_dict.items():
if name not in own_state:
unexpected_keys.append(name)
continue
if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
if param.size() != own_state[name].size():
shape_mismatch_pairs.append(
[name, own_state[name].size(),
param.size()])
continue
own_state[name].copy_(param)
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
load = None # break load->load reference cycle
all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
# ignore "num_batches_tracked" of BN layers
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
err_msg = []
if unexpected_keys:
err_msg.append('unexpected key in source state_dict: {}\n'.format(
', '.join(unexpected_keys)))
if missing_keys:
err_msg.append('missing keys in source state_dict: {}\n'.format(
', '.join(missing_keys)))
if shape_mismatch_pairs:
mismatch_info = 'these keys have mismatched shape:\n'
header = ['key', 'expected shape', 'loaded shape']
table_data = [header] + shape_mismatch_pairs
table = AsciiTable(table_data)
err_msg.append(mismatch_info + table.table)
rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
......
from .hook import Hook
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .hook import Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .optimizer import OptimizerHook
from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook, WandbLoggerHook)
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
......
import numbers
from ...dist_utils import master_only
from .base import LoggerHook
import numbers
class WandbLoggerHook(LoggerHook):
......
from .config import ConfigDict, Config
from .misc import (is_str, iter_cast, list_cast, tuple_cast, is_seq_of,
is_list_of, is_tuple_of, slice_list, concat_list,
check_prerequisites, requires_package, requires_executable)
from .path import (is_filepath, fopen, check_file_exist, mkdir_or_exist,
symlink, scandir, FileNotFoundError)
from .progressbar import (ProgressBar, track_progress, track_parallel_progress,
track_iter_progress)
from .config import Config, ConfigDict
from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
is_str, is_tuple_of, iter_cast, list_cast,
requires_executable, requires_package, slice_list,
tuple_cast)
from .path import (FileNotFoundError, check_file_exist, fopen, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .timer import Timer, TimerError, check_time
__all__ = [
......
from .io import Cache, VideoReader, frames2video
from .processing import convert_video, resize_video, cut_video, concat_video
from .optflow import (flowread, flowwrite, quantize_flow, dequantize_flow,
flow_warp)
from .optflow import (dequantize_flow, flow_warp, flowread, flowwrite,
quantize_flow)
from .processing import concat_video, convert_video, cut_video, resize_video
__all__ = [
'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
......
from .color import Color, color_val
from .image import imshow, imshow_bboxes, imshow_det_bboxes
from .optflow import flowshow, flow2rgb, make_color_wheel
from .optflow import flow2rgb, flowshow, make_color_wheel
__all__ = [
'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
......
import sys
import time
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import pytest
import pytest # isort:skip
import mmcv
import mmcv # isort:skip
def reset_string_io(io):
......
import os.path as osp
import tempfile
import warnings
from mock import MagicMock
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment