Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
779eef95
Unverified
Commit
779eef95
authored
Feb 18, 2024
by
YiYi Xu
Committed by
GitHub
Feb 19, 2024
Browse files
[from_single_file] pass `torch_dtype` to `set_module_tensor_to_device` (#6994)
fix Co-authored-by:
yiyixuxu
<
yixu310@gmail.com
>
parent
d5b8d1ca
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
14 deletions
+54
-14
src/diffusers/loaders/autoencoder.py
src/diffusers/loaders/autoencoder.py
+6
-1
src/diffusers/loaders/controlnet.py
src/diffusers/loaders/controlnet.py
+6
-1
src/diffusers/loaders/single_file.py
src/diffusers/loaders/single_file.py
+9
-2
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+33
-10
No files found.
src/diffusers/loaders/autoencoder.py
View file @
779eef95
...
@@ -138,7 +138,12 @@ class FromOriginalVAEMixin:
...
@@ -138,7 +138,12 @@ class FromOriginalVAEMixin:
image_size
=
kwargs
.
pop
(
"image_size"
,
None
)
image_size
=
kwargs
.
pop
(
"image_size"
,
None
)
scaling_factor
=
kwargs
.
pop
(
"scaling_factor"
,
None
)
scaling_factor
=
kwargs
.
pop
(
"scaling_factor"
,
None
)
component
=
create_diffusers_vae_model_from_ldm
(
component
=
create_diffusers_vae_model_from_ldm
(
class_name
,
original_config
,
checkpoint
,
image_size
=
image_size
,
scaling_factor
=
scaling_factor
class_name
,
original_config
,
checkpoint
,
image_size
=
image_size
,
scaling_factor
=
scaling_factor
,
torch_dtype
=
torch_dtype
,
)
)
vae
=
component
[
"vae"
]
vae
=
component
[
"vae"
]
if
torch_dtype
is
not
None
:
if
torch_dtype
is
not
None
:
...
...
src/diffusers/loaders/controlnet.py
View file @
779eef95
...
@@ -128,7 +128,12 @@ class FromOriginalControlNetMixin:
...
@@ -128,7 +128,12 @@ class FromOriginalControlNetMixin:
image_size
=
kwargs
.
pop
(
"image_size"
,
None
)
image_size
=
kwargs
.
pop
(
"image_size"
,
None
)
component
=
create_diffusers_controlnet_model_from_ldm
(
component
=
create_diffusers_controlnet_model_from_ldm
(
class_name
,
original_config
,
checkpoint
,
upcast_attention
=
upcast_attention
,
image_size
=
image_size
class_name
,
original_config
,
checkpoint
,
upcast_attention
=
upcast_attention
,
image_size
=
image_size
,
torch_dtype
=
torch_dtype
,
)
)
controlnet
=
component
[
"controlnet"
]
controlnet
=
component
[
"controlnet"
]
if
torch_dtype
is
not
None
:
if
torch_dtype
is
not
None
:
...
...
src/diffusers/loaders/single_file.py
View file @
779eef95
...
@@ -57,14 +57,19 @@ def build_sub_model_components(
...
@@ -57,14 +57,19 @@ def build_sub_model_components(
if
component_name
==
"unet"
:
if
component_name
==
"unet"
:
num_in_channels
=
kwargs
.
pop
(
"num_in_channels"
,
None
)
num_in_channels
=
kwargs
.
pop
(
"num_in_channels"
,
None
)
unet_components
=
create_diffusers_unet_model_from_ldm
(
unet_components
=
create_diffusers_unet_model_from_ldm
(
pipeline_class_name
,
original_config
,
checkpoint
,
num_in_channels
=
num_in_channels
,
image_size
=
image_size
pipeline_class_name
,
original_config
,
checkpoint
,
num_in_channels
=
num_in_channels
,
image_size
=
image_size
,
torch_dtype
=
torch_dtype
,
)
)
return
unet_components
return
unet_components
if
component_name
==
"vae"
:
if
component_name
==
"vae"
:
scaling_factor
=
kwargs
.
get
(
"scaling_factor"
,
None
)
scaling_factor
=
kwargs
.
get
(
"scaling_factor"
,
None
)
vae_components
=
create_diffusers_vae_model_from_ldm
(
vae_components
=
create_diffusers_vae_model_from_ldm
(
pipeline_class_name
,
original_config
,
checkpoint
,
image_size
,
scaling_factor
pipeline_class_name
,
original_config
,
checkpoint
,
image_size
,
scaling_factor
,
torch_dtype
)
)
return
vae_components
return
vae_components
...
@@ -89,6 +94,7 @@ def build_sub_model_components(
...
@@ -89,6 +94,7 @@ def build_sub_model_components(
checkpoint
,
checkpoint
,
model_type
=
model_type
,
model_type
=
model_type
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
)
)
return
text_encoder_components
return
text_encoder_components
...
@@ -261,6 +267,7 @@ class FromSingleFileMixin:
...
@@ -261,6 +267,7 @@ class FromSingleFileMixin:
image_size
=
image_size
,
image_size
=
image_size
,
load_safety_checker
=
load_safety_checker
,
load_safety_checker
=
load_safety_checker
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
**
kwargs
,
**
kwargs
,
)
)
if
not
components
:
if
not
components
:
...
...
src/diffusers/loaders/single_file_utils.py
View file @
779eef95
...
@@ -856,7 +856,7 @@ def convert_controlnet_checkpoint(
...
@@ -856,7 +856,7 @@ def convert_controlnet_checkpoint(
def
create_diffusers_controlnet_model_from_ldm
(
def
create_diffusers_controlnet_model_from_ldm
(
pipeline_class_name
,
original_config
,
checkpoint
,
upcast_attention
=
False
,
image_size
=
None
pipeline_class_name
,
original_config
,
checkpoint
,
upcast_attention
=
False
,
image_size
=
None
,
torch_dtype
=
None
):
):
# import here to avoid circular imports
# import here to avoid circular imports
from
..models
import
ControlNetModel
from
..models
import
ControlNetModel
...
@@ -875,7 +875,9 @@ def create_diffusers_controlnet_model_from_ldm(
...
@@ -875,7 +875,9 @@ def create_diffusers_controlnet_model_from_ldm(
if
is_accelerate_available
():
if
is_accelerate_available
():
from
..models.modeling_utils
import
load_model_dict_into_meta
from
..models.modeling_utils
import
load_model_dict_into_meta
unexpected_keys
=
load_model_dict_into_meta
(
controlnet
,
diffusers_format_controlnet_checkpoint
)
unexpected_keys
=
load_model_dict_into_meta
(
controlnet
,
diffusers_format_controlnet_checkpoint
,
torch_dtype
=
torch_dtype
)
if
controlnet
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
if
controlnet
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
for
pat
in
controlnet
.
_keys_to_ignore_on_load_unexpected
:
for
pat
in
controlnet
.
_keys_to_ignore_on_load_unexpected
:
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
...
@@ -887,6 +889,9 @@ def create_diffusers_controlnet_model_from_ldm(
...
@@ -887,6 +889,9 @@ def create_diffusers_controlnet_model_from_ldm(
else
:
else
:
controlnet
.
load_state_dict
(
diffusers_format_controlnet_checkpoint
)
controlnet
.
load_state_dict
(
diffusers_format_controlnet_checkpoint
)
if
torch_dtype
is
not
None
:
controlnet
=
controlnet
.
to
(
torch_dtype
)
return
{
"controlnet"
:
controlnet
}
return
{
"controlnet"
:
controlnet
}
...
@@ -1022,7 +1027,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
...
@@ -1022,7 +1027,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return
new_checkpoint
return
new_checkpoint
def
create_text_encoder_from_ldm_clip_checkpoint
(
config_name
,
checkpoint
,
local_files_only
=
False
):
def
create_text_encoder_from_ldm_clip_checkpoint
(
config_name
,
checkpoint
,
local_files_only
=
False
,
torch_dtype
=
None
):
try
:
try
:
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
except
Exception
:
except
Exception
:
...
@@ -1048,7 +1053,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
...
@@ -1048,7 +1053,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
if
is_accelerate_available
():
if
is_accelerate_available
():
from
..models.modeling_utils
import
load_model_dict_into_meta
from
..models.modeling_utils
import
load_model_dict_into_meta
unexpected_keys
=
load_model_dict_into_meta
(
text_model
,
text_model_dict
)
unexpected_keys
=
load_model_dict_into_meta
(
text_model
,
text_model_dict
,
dtype
=
torch_dtype
)
if
text_model
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
if
text_model
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
for
pat
in
text_model
.
_keys_to_ignore_on_load_unexpected
:
for
pat
in
text_model
.
_keys_to_ignore_on_load_unexpected
:
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
...
@@ -1063,6 +1068,9 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
...
@@ -1063,6 +1068,9 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
text_model
.
load_state_dict
(
text_model_dict
)
text_model
.
load_state_dict
(
text_model_dict
)
if
torch_dtype
is
not
None
:
text_model
=
text_model
.
to
(
torch_dtype
)
return
text_model
return
text_model
...
@@ -1072,6 +1080,7 @@ def create_text_encoder_from_open_clip_checkpoint(
...
@@ -1072,6 +1080,7 @@ def create_text_encoder_from_open_clip_checkpoint(
prefix
=
"cond_stage_model.model."
,
prefix
=
"cond_stage_model.model."
,
has_projection
=
False
,
has_projection
=
False
,
local_files_only
=
False
,
local_files_only
=
False
,
torch_dtype
=
None
,
**
config_kwargs
,
**
config_kwargs
,
):
):
try
:
try
:
...
@@ -1139,7 +1148,7 @@ def create_text_encoder_from_open_clip_checkpoint(
...
@@ -1139,7 +1148,7 @@ def create_text_encoder_from_open_clip_checkpoint(
if
is_accelerate_available
():
if
is_accelerate_available
():
from
..models.modeling_utils
import
load_model_dict_into_meta
from
..models.modeling_utils
import
load_model_dict_into_meta
unexpected_keys
=
load_model_dict_into_meta
(
text_model
,
text_model_dict
)
unexpected_keys
=
load_model_dict_into_meta
(
text_model
,
text_model_dict
,
dtype
=
torch_dtype
)
if
text_model
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
if
text_model
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
for
pat
in
text_model
.
_keys_to_ignore_on_load_unexpected
:
for
pat
in
text_model
.
_keys_to_ignore_on_load_unexpected
:
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
...
@@ -1155,6 +1164,9 @@ def create_text_encoder_from_open_clip_checkpoint(
...
@@ -1155,6 +1164,9 @@ def create_text_encoder_from_open_clip_checkpoint(
text_model
.
load_state_dict
(
text_model_dict
)
text_model
.
load_state_dict
(
text_model_dict
)
if
torch_dtype
is
not
None
:
text_model
=
text_model
.
to
(
torch_dtype
)
return
text_model
return
text_model
...
@@ -1166,6 +1178,7 @@ def create_diffusers_unet_model_from_ldm(
...
@@ -1166,6 +1178,7 @@ def create_diffusers_unet_model_from_ldm(
upcast_attention
=
False
,
upcast_attention
=
False
,
extract_ema
=
False
,
extract_ema
=
False
,
image_size
=
None
,
image_size
=
None
,
torch_dtype
=
None
,
):
):
from
..models
import
UNet2DConditionModel
from
..models
import
UNet2DConditionModel
...
@@ -1198,7 +1211,7 @@ def create_diffusers_unet_model_from_ldm(
...
@@ -1198,7 +1211,7 @@ def create_diffusers_unet_model_from_ldm(
if
is_accelerate_available
():
if
is_accelerate_available
():
from
..models.modeling_utils
import
load_model_dict_into_meta
from
..models.modeling_utils
import
load_model_dict_into_meta
unexpected_keys
=
load_model_dict_into_meta
(
unet
,
diffusers_format_unet_checkpoint
)
unexpected_keys
=
load_model_dict_into_meta
(
unet
,
diffusers_format_unet_checkpoint
,
dtype
=
torch_dtype
)
if
unet
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
if
unet
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
for
pat
in
unet
.
_keys_to_ignore_on_load_unexpected
:
for
pat
in
unet
.
_keys_to_ignore_on_load_unexpected
:
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
...
@@ -1210,11 +1223,14 @@ def create_diffusers_unet_model_from_ldm(
...
@@ -1210,11 +1223,14 @@ def create_diffusers_unet_model_from_ldm(
else
:
else
:
unet
.
load_state_dict
(
diffusers_format_unet_checkpoint
)
unet
.
load_state_dict
(
diffusers_format_unet_checkpoint
)
if
torch_dtype
is
not
None
:
unet
=
unet
.
to
(
torch_dtype
)
return
{
"unet"
:
unet
}
return
{
"unet"
:
unet
}
def
create_diffusers_vae_model_from_ldm
(
def
create_diffusers_vae_model_from_ldm
(
pipeline_class_name
,
original_config
,
checkpoint
,
image_size
=
None
,
scaling_factor
=
None
pipeline_class_name
,
original_config
,
checkpoint
,
image_size
=
None
,
scaling_factor
=
None
,
torch_dtype
=
None
):
):
# import here to avoid circular imports
# import here to avoid circular imports
from
..models
import
AutoencoderKL
from
..models
import
AutoencoderKL
...
@@ -1231,7 +1247,7 @@ def create_diffusers_vae_model_from_ldm(
...
@@ -1231,7 +1247,7 @@ def create_diffusers_vae_model_from_ldm(
if
is_accelerate_available
():
if
is_accelerate_available
():
from
..models.modeling_utils
import
load_model_dict_into_meta
from
..models.modeling_utils
import
load_model_dict_into_meta
unexpected_keys
=
load_model_dict_into_meta
(
vae
,
diffusers_format_vae_checkpoint
)
unexpected_keys
=
load_model_dict_into_meta
(
vae
,
diffusers_format_vae_checkpoint
,
dtype
=
torch_dtype
)
if
vae
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
if
vae
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
for
pat
in
vae
.
_keys_to_ignore_on_load_unexpected
:
for
pat
in
vae
.
_keys_to_ignore_on_load_unexpected
:
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
unexpected_keys
=
[
k
for
k
in
unexpected_keys
if
re
.
search
(
pat
,
k
)
is
None
]
...
@@ -1243,6 +1259,9 @@ def create_diffusers_vae_model_from_ldm(
...
@@ -1243,6 +1259,9 @@ def create_diffusers_vae_model_from_ldm(
else
:
else
:
vae
.
load_state_dict
(
diffusers_format_vae_checkpoint
)
vae
.
load_state_dict
(
diffusers_format_vae_checkpoint
)
if
torch_dtype
is
not
None
:
vae
=
vae
.
to
(
torch_dtype
)
return
{
"vae"
:
vae
}
return
{
"vae"
:
vae
}
...
@@ -1251,6 +1270,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1251,6 +1270,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
checkpoint
,
checkpoint
,
model_type
=
None
,
model_type
=
None
,
local_files_only
=
False
,
local_files_only
=
False
,
torch_dtype
=
None
,
):
):
model_type
=
infer_model_type
(
original_config
,
model_type
=
model_type
)
model_type
=
infer_model_type
(
original_config
,
model_type
=
model_type
)
...
@@ -1260,7 +1280,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1260,7 +1280,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
try
:
try
:
text_encoder
=
create_text_encoder_from_open_clip_checkpoint
(
text_encoder
=
create_text_encoder_from_open_clip_checkpoint
(
config_name
,
checkpoint
,
local_files_only
=
local_files_only
,
**
config_kwargs
config_name
,
checkpoint
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
**
config_kwargs
)
)
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
config_name
,
subfolder
=
"tokenizer"
,
local_files_only
=
local_files_only
config_name
,
subfolder
=
"tokenizer"
,
local_files_only
=
local_files_only
...
@@ -1279,6 +1299,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1279,6 +1299,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
config_name
,
config_name
,
checkpoint
,
checkpoint
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
)
)
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
...
@@ -1302,6 +1323,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1302,6 +1323,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
prefix
=
prefix
,
prefix
=
prefix
,
has_projection
=
True
,
has_projection
=
True
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
**
config_kwargs
,
**
config_kwargs
,
)
)
except
Exception
:
except
Exception
:
...
@@ -1322,7 +1344,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1322,7 +1344,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
config_name
=
"openai/clip-vit-large-patch14"
config_name
=
"openai/clip-vit-large-patch14"
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
config_name
,
local_files_only
=
local_files_only
)
text_encoder
=
create_text_encoder_from_ldm_clip_checkpoint
(
text_encoder
=
create_text_encoder_from_ldm_clip_checkpoint
(
config_name
,
checkpoint
,
local_files_only
=
local_files_only
config_name
,
checkpoint
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
)
)
except
Exception
:
except
Exception
:
...
@@ -1341,6 +1363,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
...
@@ -1341,6 +1363,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
prefix
=
prefix
,
prefix
=
prefix
,
has_projection
=
True
,
has_projection
=
True
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
torch_dtype
=
torch_dtype
,
**
config_kwargs
,
**
config_kwargs
,
)
)
except
Exception
:
except
Exception
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment