renzhc / diffusers_dcu · Commit 779eef95 (unverified)
Authored Feb 18, 2024 by YiYi Xu; committed via GitHub on Feb 19, 2024
[from_single_file] pass `torch_dtype` to `set_module_tensor_to_device` (#6994)
fix

Co-authored-by: yiyixuxu <yixu310@gmail.com>

Parent: d5b8d1ca
Showing 4 changed files with 54 additions and 14 deletions (+54, -14):

src/diffusers/loaders/autoencoder.py        +6  -1
src/diffusers/loaders/controlnet.py         +6  -1
src/diffusers/loaders/single_file.py        +9  -2
src/diffusers/loaders/single_file_utils.py  +33 -10
src/diffusers/loaders/autoencoder.py

@@ -138,7 +138,12 @@ class FromOriginalVAEMixin:
         image_size = kwargs.pop("image_size", None)
         scaling_factor = kwargs.pop("scaling_factor", None)
         component = create_diffusers_vae_model_from_ldm(
-            class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
+            class_name,
+            original_config,
+            checkpoint,
+            image_size=image_size,
+            scaling_factor=scaling_factor,
+            torch_dtype=torch_dtype,
         )
         vae = component["vae"]
         if torch_dtype is not None:
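For reference, the user-facing effect of the VAE change: a `torch_dtype` passed to `from_single_file` is now forwarded into `create_diffusers_vae_model_from_ldm` rather than applied only after loading. A minimal sketch of the call site (the checkpoint path is hypothetical):

import torch
from diffusers import AutoencoderKL

# Hypothetical path to an original-format (LDM) VAE checkpoint.
vae = AutoencoderKL.from_single_file(
    "./vae-ft-mse-840000-ema-pruned.safetensors",
    torch_dtype=torch.float16,
)
print(next(vae.parameters()).dtype)  # torch.float16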
src/diffusers/loaders/controlnet.py

@@ -128,7 +128,12 @@ class FromOriginalControlNetMixin:
         image_size = kwargs.pop("image_size", None)
 
         component = create_diffusers_controlnet_model_from_ldm(
-            class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
+            class_name,
+            original_config,
+            checkpoint,
+            upcast_attention=upcast_attention,
+            image_size=image_size,
+            torch_dtype=torch_dtype,
         )
         controlnet = component["controlnet"]
         if torch_dtype is not None:
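The ControlNet mixin gets the same treatment; a sketch under the same assumptions (hypothetical checkpoint path):

import torch
from diffusers import ControlNetModel

# Hypothetical path to an original-format ControlNet checkpoint.
controlnet = ControlNetModel.from_single_file(
    "./control_v11p_sd15_canny.pth",
    torch_dtype=torch.float16,
)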
src/diffusers/loaders/single_file.py

@@ -57,14 +57,19 @@ def build_sub_model_components(
     if component_name == "unet":
         num_in_channels = kwargs.pop("num_in_channels", None)
         unet_components = create_diffusers_unet_model_from_ldm(
-            pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
+            pipeline_class_name,
+            original_config,
+            checkpoint,
+            num_in_channels=num_in_channels,
+            image_size=image_size,
+            torch_dtype=torch_dtype,
         )
         return unet_components
 
     if component_name == "vae":
         scaling_factor = kwargs.get("scaling_factor", None)
         vae_components = create_diffusers_vae_model_from_ldm(
-            pipeline_class_name, original_config, checkpoint, image_size, scaling_factor
+            pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
         )
         return vae_components

@@ -89,6 +94,7 @@ def build_sub_model_components(
             checkpoint,
             model_type=model_type,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
         )
         return text_encoder_components

@@ -261,6 +267,7 @@ class FromSingleFileMixin:
             image_size=image_size,
             load_safety_checker=load_safety_checker,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
             **kwargs,
         )
 
         if not components:
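With these hunks, `build_sub_model_components` threads `torch_dtype` into the UNet, VAE, and text-encoder builders, so every sub-model of a pipeline is materialized in the requested dtype. A sketch of the entry point (hypothetical checkpoint path):

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical path to a single-file Stable Diffusion checkpoint.
pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.safetensors",
    torch_dtype=torch.float16,
)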
src/diffusers/loaders/single_file_utils.py

@@ -856,7 +856,7 @@ def convert_controlnet_checkpoint(
 def create_diffusers_controlnet_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None
+    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
 ):
     # import here to avoid circular imports
     from ..models import ControlNetModel

@@ -875,7 +875,9 @@ def create_diffusers_controlnet_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta
 
-        unexpected_keys = load_model_dict_into_meta(controlnet, diffusers_format_controlnet_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(
+            controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
+        )
         if controlnet._keys_to_ignore_on_load_unexpected is not None:
             for pat in controlnet._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
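As the commit title says, the dtype forwarded to `load_model_dict_into_meta` ends up in accelerate's `set_module_tensor_to_device`, which casts each tensor as it is materialized from the meta device. A minimal sketch of that mechanism using accelerate directly (the `Linear` module and state dict are stand-ins for the real model):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Stand-in for the real ControlNet/UNet/VAE built on the meta device.
with init_empty_weights():
    model = torch.nn.Linear(4, 4)

state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Passing dtype casts each tensor while it is materialized, so a full
# float32 copy of the weights never has to exist alongside the model.
for name, tensor in state_dict.items():
    set_module_tensor_to_device(model, name, "cpu", value=tensor, dtype=torch.float16)

print(model.weight.dtype)  # torch.float16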
@@ -887,6 +889,9 @@ def create_diffusers_controlnet_model_from_ldm(
     else:
         controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
 
+    if torch_dtype is not None:
+        controlnet = controlnet.to(torch_dtype)
+
     return {"controlnet": controlnet}

@@ -1022,7 +1027,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
-def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False):
+def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
     try:
         config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
     except Exception:

@@ -1048,7 +1053,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta
 
-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
         if text_model._keys_to_ignore_on_load_unexpected is not None:
             for pat in text_model._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

@@ -1063,6 +1068,9 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
         text_model.load_state_dict(text_model_dict)
 
+    if torch_dtype is not None:
+        text_model = text_model.to(torch_dtype)
+
     return text_model
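The `if torch_dtype is not None:` blocks added after each `load_state_dict` cover the fallback path when accelerate is unavailable: load in the checkpoint's dtype, then cast explicitly. A sketch with a stand-in module:

import torch

torch_dtype = torch.float16
model = torch.nn.Linear(4, 4)  # stand-in for text_model / unet / vae

model.load_state_dict({"weight": torch.randn(4, 4), "bias": torch.randn(4)})

# Mirrors the fallback branch added in this commit.
if torch_dtype is not None:
    model = model.to(torch_dtype)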
@@ -1072,6 +1080,7 @@ def create_text_encoder_from_open_clip_checkpoint(
     prefix="cond_stage_model.model.",
     has_projection=False,
     local_files_only=False,
+    torch_dtype=None,
     **config_kwargs,
 ):
     try:

@@ -1139,7 +1148,7 @@ def create_text_encoder_from_open_clip_checkpoint(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta
 
-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
         if text_model._keys_to_ignore_on_load_unexpected is not None:
             for pat in text_model._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

@@ -1155,6 +1164,9 @@ def create_text_encoder_from_open_clip_checkpoint(
         text_model.load_state_dict(text_model_dict)
 
+    if torch_dtype is not None:
+        text_model = text_model.to(torch_dtype)
+
     return text_model

@@ -1166,6 +1178,7 @@ def create_diffusers_unet_model_from_ldm(
     upcast_attention=False,
     extract_ema=False,
     image_size=None,
+    torch_dtype=None,
 ):
     from ..models import UNet2DConditionModel

@@ -1198,7 +1211,7 @@ def create_diffusers_unet_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta
 
-        unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
         if unet._keys_to_ignore_on_load_unexpected is not None:
             for pat in unet._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

@@ -1210,11 +1223,14 @@ def create_diffusers_unet_model_from_ldm(
     else:
         unet.load_state_dict(diffusers_format_unet_checkpoint)
 
+    if torch_dtype is not None:
+        unet = unet.to(torch_dtype)
+
     return {"unet": unet}
 
 
 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL

@@ -1231,7 +1247,7 @@ def create_diffusers_vae_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta
 
-        unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
         if vae._keys_to_ignore_on_load_unexpected is not None:
             for pat in vae._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

@@ -1243,6 +1259,9 @@ def create_diffusers_vae_model_from_ldm(
     else:
         vae.load_state_dict(diffusers_format_vae_checkpoint)
 
+    if torch_dtype is not None:
+        vae = vae.to(torch_dtype)
+
     return {"vae": vae}

@@ -1251,6 +1270,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
     checkpoint,
     model_type=None,
     local_files_only=False,
+    torch_dtype=None,
 ):
     model_type = infer_model_type(original_config, model_type=model_type)

@@ -1260,7 +1280,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
         try:
             text_encoder = create_text_encoder_from_open_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only, **config_kwargs
+                config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
             )
             tokenizer = CLIPTokenizer.from_pretrained(
                 config_name, subfolder="tokenizer", local_files_only=local_files_only

@@ -1279,6 +1299,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
             config_name,
             checkpoint,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
         )
         tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)

@@ -1302,6 +1323,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
             prefix=prefix,
             has_projection=True,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
             **config_kwargs,
         )
     except Exception:

@@ -1322,7 +1344,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
             config_name = "openai/clip-vit-large-patch14"
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
             text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only
+                config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
             )
     except Exception:

@@ -1341,6 +1363,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
             prefix=prefix,
             has_projection=True,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
             **config_kwargs,
         )
     except Exception:
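Taken together, the cast now happens while tensors are materialized instead of as a post-hoc pass over a float32 model, so loading a float32 checkpoint in half precision no longer requires a transient full-precision copy of each sub-model. A quick sanity check (hypothetical checkpoint path, as above):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.safetensors", torch_dtype=torch.float16
)
for name in ("unet", "vae", "text_encoder"):
    print(name, next(getattr(pipe, name).parameters()).dtype)  # torch.float16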