Unverified Commit f523b11a authored by Patrick von Platen, committed by GitHub

Fix loading if unexpected keys are present (#3720)

* Fix loading

* make style
parent 79fa94ea
@@ -17,6 +17,7 @@
 import inspect
 import itertools
 import os
+import re
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -162,6 +163,7 @@ class ModelMixin(torch.nn.Module):
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
+    _keys_to_ignore_on_load_unexpected = None
 
     def __init__(self):
         super().__init__()
@@ -608,6 +610,7 @@ class ModelMixin(torch.nn.Module):
                         " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                         " those weights or else make sure your checkpoint file is correct."
                     )
+                unexpected_keys = []
 
                 empty_state_dict = model.state_dict()
                 for param_name, param in state_dict.items():
@@ -615,6 +618,10 @@ class ModelMixin(torch.nn.Module):
                         inspect.signature(set_module_tensor_to_device).parameters.keys()
                     )
 
+                    if param_name not in empty_state_dict:
+                        unexpected_keys.append(param_name)
+                        continue
+
                     if empty_state_dict[param_name].shape != param.shape:
                         raise ValueError(
                             f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -626,6 +633,16 @@ class ModelMixin(torch.nn.Module):
                         )
                     else:
                         set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+                if cls._keys_to_ignore_on_load_unexpected is not None:
+                    for pat in cls._keys_to_ignore_on_load_unexpected:
+                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                if len(unexpected_keys) > 0:
+                    logger.warn(
+                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                    )
+
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
...
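For readers skimming the diff: below is a minimal, standalone sketch of what the new unexpected-key handling does, not the diffusers implementation itself. The helper name `filter_and_report_unexpected_keys`, the plain dicts standing in for tensors, and the example keys are illustrative assumptions only.

```python
import re


# A minimal, standalone sketch (NOT the diffusers implementation) of the behavior
# added above: checkpoint keys that the model does not expect are collected,
# optionally filtered by regex patterns, and reported in a warning instead of
# crashing the load.
def filter_and_report_unexpected_keys(model_state_dict, checkpoint_state_dict, ignore_patterns=None):
    unexpected_keys = []
    for param_name in checkpoint_state_dict:
        # A key present in the checkpoint but missing from the model is "unexpected".
        if param_name not in model_state_dict:
            unexpected_keys.append(param_name)

    # Mirror the _keys_to_ignore_on_load_unexpected filtering introduced in this commit.
    if ignore_patterns is not None:
        for pat in ignore_patterns:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        print(f"Some weights of the checkpoint were not used: {', '.join(unexpected_keys)}")
    return unexpected_keys


# Toy usage: the GPT-2 style attention buffer is ignored, the stray key is reported.
model_sd = {"linear.weight": 0, "linear.bias": 0}
ckpt_sd = {"linear.weight": 1, "linear.bias": 1, "h.0.attn.bias": 1, "stray.weight": 1}
filter_and_report_unexpected_keys(model_sd, ckpt_sd, ignore_patterns=[r"h\.\d+\.attn\.bias"])
# prints: Some weights of the checkpoint were not used: stray.weight
```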
@@ -61,6 +61,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
         dot-product/softmax to float() when training with mixed precision.
     """
 
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
+
     @register_to_config
     def __init__(
         self,
...
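And a hedged sketch of how any `ModelMixin` subclass can opt into the new hook, in the same spirit as the `UniDiffuserTextDecoder` change above. The class `TinyDecoder`, its `hidden_size` parameter, and its single linear layer are made up for illustration; only `_keys_to_ignore_on_load_unexpected` and the regex patterns come from the diff.

```python
import torch

from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config


class TinyDecoder(ModelMixin, ConfigMixin):
    # Checkpoint entries matching these patterns (e.g. GPT-2 style attention bias
    # buffers) are dropped from the "unexpected keys" list, so loading such a
    # checkpoint with from_pretrained() no longer warns about them.
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
```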