Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8267c784
Unverified
Commit
8267c784
authored
Feb 01, 2023
by
Patrick von Platen
Committed by
GitHub
Feb 01, 2023
Browse files
[Loading] Better error message on missing keys (#2198)
* up * finish
parent
4fc70848
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
1 deletion
+19
-1
src/diffusers/models/modeling_utils.py
src/diffusers/models/modeling_utils.py
+9
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+10
-1
No files found.
src/diffusers/models/modeling_utils.py
View file @
8267c784
...
@@ -541,6 +541,15 @@ class ModelMixin(torch.nn.Module):
...
@@ -541,6 +541,15 @@ class ModelMixin(torch.nn.Module):
param_device
=
"cpu"
param_device
=
"cpu"
state_dict
=
load_state_dict
(
model_file
)
state_dict
=
load_state_dict
(
model_file
)
# move the params from meta device to cpu
# move the params from meta device to cpu
missing_keys
=
set
(
model
.
state_dict
().
keys
())
-
set
(
state_dict
.
keys
())
if
len
(
missing_keys
)
>
0
:
raise
ValueError
(
f
"Cannot load
{
cls
}
from
{
pretrained_model_name_or_path
}
because the following keys are"
f
" missing:
\n
{
', '
.
join
(
missing_keys
)
}
.
\n
Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
" those weights or else make sure your checkpoint file is correct."
)
for
param_name
,
param
in
state_dict
.
items
():
for
param_name
,
param
in
state_dict
.
items
():
accepts_dtype
=
"dtype"
in
set
(
accepts_dtype
=
"dtype"
in
set
(
inspect
.
signature
(
set_module_tensor_to_device
).
parameters
.
keys
()
inspect
.
signature
(
set_module_tensor_to_device
).
parameters
.
keys
()
...
...
tests/test_modeling_common.py
View file @
8267c784
...
@@ -21,11 +21,20 @@ from typing import Dict, List, Tuple
...
@@ -21,11 +21,20 @@ from typing import Dict, List, Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
diffusers.models
import
ModelMixin
from
diffusers.models
import
ModelMixin
,
UNet2DConditionModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
torch_device
from
diffusers.utils
import
torch_device
class
ModelUtilsTest
(
unittest
.
TestCase
):
def
test_accelerate_loading_error_message
(
self
):
with
self
.
assertRaises
(
ValueError
)
as
error_context
:
UNet2DConditionModel
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken"
,
subfolder
=
"unet"
)
# make sure that error message states what keys are missing
assert
"conv_out.bias"
in
str
(
error_context
.
exception
)
class
ModelTesterMixin
:
class
ModelTesterMixin
:
def
test_from_save_pretrained
(
self
):
def
test_from_save_pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment