Unverified Commit 8267c784 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Loading] Better error message on missing keys (#2198)

* up

* finish
parent 4fc70848
......@@ -541,6 +541,15 @@ class ModelMixin(torch.nn.Module):
param_device = "cpu"
state_dict = load_state_dict(model_file)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
" those weights or else make sure your checkpoint file is correct."
)
for param_name, param in state_dict.items():
accepts_dtype = "dtype" in set(
inspect.signature(set_module_tensor_to_device).parameters.keys()
......
......@@ -21,11 +21,20 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
from diffusers.models import ModelMixin
from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase):
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
class ModelTesterMixin:
def test_from_save_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment