# Copyright (c) OpenMMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo

import mmcv
from mmcv.parallel import is_module_wrapper
from mmcv.runner.dist_utils import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is True and keys do not match. The
            mismatch is only reported (and raised) on rank 0; other ranks
            return silently.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Copy so the caller's mapping is not touched, but carry the version
    # metadata over since ``dict.copy`` drops the ``_metadata`` attribute.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        # Pass strict=True to torch's internal hook so that missing and
        # unexpected keys are collected; whether to raise is decided below
        # from this function's own ``strict`` argument.
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    # Only rank 0 reports, to avoid duplicated messages in distributed runs.
    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


class CheckpointLoader:
    """A general checkpoint loader to manage all schemes."""

    # Maps a path prefix (e.g. a URL scheme) to the loader registered for it.
    # Kept sorted in reverse lexicographic order by ``_register_scheme``.
    _schemes = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        """Register ``loader`` under one or more path prefixes.

        Raises:
            KeyError: If a prefix is already registered and ``force`` is
                False.
        """
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))
        for prefix in prefixes:
            if (prefix not in cls._schemes) or force:
                cls._schemes[prefix] = loader
            else:
                raise KeyError(
                    f'{prefix} is already registered as a loader backend, '
                    'add "force=True" if you want to override it')
        # sort, longer prefixes take priority: in reverse lexicographic order
        # a prefix always comes before any shorter prefix of itself, so the
        # first ``startswith`` match in ``_get_checkpoint_loader`` is the
        # most specific one (and '' sorts last).
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
            The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.
        """

        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Finds a loader that supports the given path.

        The most specific (longest) registered prefix matching ``path`` wins.
        Since every path starts with the empty string, a loader registered
        under the prefix ``''`` acts as the fallback.

        Args:
            path (str): checkpoint path

        Returns:
            loader (function or None): checkpoint loader, or None if no
            registered prefix matches ``path``.
        """

        for p in cls._schemes:
            if path.startswith(p):
                return cls._schemes[p]
        return None

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger=None):
        """load checkpoint through URL scheme path.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (:mod:`logging.Logger`, optional): The logger for message.
                Default: None

        Returns:
            dict or OrderedDict: The loaded checkpoint.

        Raises:
            ValueError: If no registered loader matches ``filename``.
        """

        checkpoint_loader = cls._get_checkpoint_loader(filename)
        if checkpoint_loader is None:
            # Fail with a clear message instead of an AttributeError on
            # ``None.__name__`` below.
            raise ValueError(
                f'No checkpoint loader registered for path: {filename}')
        class_name = checkpoint_loader.__name__
        # Drops the first 10 characters of the loader's name (presumably a
        # "load_from_" prefix -- TODO confirm against registered loaders).
        mmcv.print_log(
            f'load checkpoint from {class_name[10:]} path: {filename}', logger)
        return checkpoint_loader(filename, map_location)


def _load_checkpoint(filename, map_location=None, logger=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str, optional): Same as :func:`torch.load`.
           Default: None.
        logger (:mod:`logging.Logger`, optional): The logger for error message.
           Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint. It can be either an
           OrderedDict storing model weights or a dict containing other
           information, which depends on the checkpoint.
    """
    return CheckpointLoader.load_checkpoint(filename, map_location, logger)


def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=((r'^module\.', ''), )):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list or tuple): A sequence of customized keywords to
            modify the state_dict in checkpoint. Each item is a
            (pattern, replacement) pair of the regular expression operations,
            applied in order to every key. Default: strip the prefix
            'module.' by ((r'^module\\.', ''), ).

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict, or (via
            :func:`load_state_dict`) if ``strict`` is True and keys mismatch.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # strip prefix of state_dict
    metadata = getattr(state_dict, '_metadata', OrderedDict())
    # Always work on a fresh OrderedDict: a plain ``dict`` cannot carry the
    # ``_metadata`` attribute (which would fail when ``revise_keys`` is
    # empty), and copying keeps the returned ``checkpoint`` unmodified.
    state_dict = OrderedDict(state_dict)
    for p, r in revise_keys:
        state_dict = OrderedDict(
            {re.sub(p, r, k): v for k, v in state_dict.items()})
    # Keep metadata in state_dict
    state_dict._metadata = metadata

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint