Unverified commit 779eef95 authored by YiYi Xu, committed by GitHub

[from_single_file] pass `torch_dtype` to `set_module_tensor_to_device` (#6994)



fix
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent d5b8d1ca
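
In practice, this change means the `torch_dtype` passed to `from_single_file` is applied while the checkpoint weights are loaded, rather than relying solely on a cast after the models are constructed. A minimal usage sketch (the checkpoint path below is a placeholder):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load a single-file (LDM-style) checkpoint directly in half precision.
pipe = StableDiffusionPipeline.from_single_file(
    "path/to/checkpoint.safetensors",  # placeholder for a local checkpoint
    torch_dtype=torch.float16,
)

# With this change, the sub-models come back already cast to the requested dtype.
assert pipe.unet.dtype == torch.float16
assert pipe.vae.dtype == torch.float16
```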
@@ -138,7 +138,12 @@ class FromOriginalVAEMixin:
         image_size = kwargs.pop("image_size", None)
         scaling_factor = kwargs.pop("scaling_factor", None)
         component = create_diffusers_vae_model_from_ldm(
-            class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
+            class_name,
+            original_config,
+            checkpoint,
+            image_size=image_size,
+            scaling_factor=scaling_factor,
+            torch_dtype=torch_dtype,
         )
         vae = component["vae"]
         if torch_dtype is not None:
...
@@ -128,7 +128,12 @@ class FromOriginalControlNetMixin:
         image_size = kwargs.pop("image_size", None)
         component = create_diffusers_controlnet_model_from_ldm(
-            class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
+            class_name,
+            original_config,
+            checkpoint,
+            upcast_attention=upcast_attention,
+            image_size=image_size,
+            torch_dtype=torch_dtype,
         )
         controlnet = component["controlnet"]
         if torch_dtype is not None:
...
@@ -57,14 +57,19 @@ def build_sub_model_components(
     if component_name == "unet":
         num_in_channels = kwargs.pop("num_in_channels", None)
         unet_components = create_diffusers_unet_model_from_ldm(
-            pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
+            pipeline_class_name,
+            original_config,
+            checkpoint,
+            num_in_channels=num_in_channels,
+            image_size=image_size,
+            torch_dtype=torch_dtype,
         )
         return unet_components

     if component_name == "vae":
         scaling_factor = kwargs.get("scaling_factor", None)
         vae_components = create_diffusers_vae_model_from_ldm(
-            pipeline_class_name, original_config, checkpoint, image_size, scaling_factor
+            pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
         )
         return vae_components
@@ -89,6 +94,7 @@ def build_sub_model_components(
             checkpoint,
             model_type=model_type,
             local_files_only=local_files_only,
+            torch_dtype=torch_dtype,
         )
         return text_encoder_components
@@ -261,6 +267,7 @@ class FromSingleFileMixin:
            image_size=image_size,
            load_safety_checker=load_safety_checker,
            local_files_only=local_files_only,
+           torch_dtype=torch_dtype,
            **kwargs,
        )
        if not components:
...
@@ -856,7 +856,7 @@ def convert_controlnet_checkpoint(

 def create_diffusers_controlnet_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None
+    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
 ):
     # import here to avoid circular imports
     from ..models import ControlNetModel
@@ -875,7 +875,9 @@ def create_diffusers_controlnet_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta

-        unexpected_keys = load_model_dict_into_meta(controlnet, diffusers_format_controlnet_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(
+            controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
+        )
         if controlnet._keys_to_ignore_on_load_unexpected is not None:
             for pat in controlnet._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -887,6 +889,9 @@ def create_diffusers_controlnet_model_from_ldm(
     else:
         controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)

+    if torch_dtype is not None:
+        controlnet = controlnet.to(torch_dtype)
+
     return {"controlnet": controlnet}
@@ -1022,7 +1027,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint


-def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False):
+def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
     try:
         config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
     except Exception:
@@ -1048,7 +1053,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta

-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
         if text_model._keys_to_ignore_on_load_unexpected is not None:
             for pat in text_model._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -1063,6 +1068,9 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
         text_model.load_state_dict(text_model_dict)

+    if torch_dtype is not None:
+        text_model = text_model.to(torch_dtype)
+
     return text_model
@@ -1072,6 +1080,7 @@ def create_text_encoder_from_open_clip_checkpoint(
     prefix="cond_stage_model.model.",
     has_projection=False,
     local_files_only=False,
+    torch_dtype=None,
     **config_kwargs,
 ):
     try:
@@ -1139,7 +1148,7 @@ def create_text_encoder_from_open_clip_checkpoint(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta

-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
         if text_model._keys_to_ignore_on_load_unexpected is not None:
             for pat in text_model._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -1155,6 +1164,9 @@ def create_text_encoder_from_open_clip_checkpoint(
         text_model.load_state_dict(text_model_dict)

+    if torch_dtype is not None:
+        text_model = text_model.to(torch_dtype)
+
     return text_model
@@ -1166,6 +1178,7 @@ def create_diffusers_unet_model_from_ldm(
     upcast_attention=False,
     extract_ema=False,
     image_size=None,
+    torch_dtype=None,
 ):
     from ..models import UNet2DConditionModel
@@ -1198,7 +1211,7 @@ def create_diffusers_unet_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta

-        unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
         if unet._keys_to_ignore_on_load_unexpected is not None:
             for pat in unet._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -1210,11 +1223,14 @@ def create_diffusers_unet_model_from_ldm(
     else:
         unet.load_state_dict(diffusers_format_unet_checkpoint)

+    if torch_dtype is not None:
+        unet = unet.to(torch_dtype)
+
     return {"unet": unet}


 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL
@@ -1231,7 +1247,7 @@ def create_diffusers_vae_model_from_ldm(
     if is_accelerate_available():
         from ..models.modeling_utils import load_model_dict_into_meta

-        unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint)
+        unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
         if vae._keys_to_ignore_on_load_unexpected is not None:
             for pat in vae._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -1243,6 +1259,9 @@ def create_diffusers_vae_model_from_ldm(
     else:
         vae.load_state_dict(diffusers_format_vae_checkpoint)

+    if torch_dtype is not None:
+        vae = vae.to(torch_dtype)
+
     return {"vae": vae}
@@ -1251,6 +1270,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
     checkpoint,
     model_type=None,
     local_files_only=False,
+    torch_dtype=None,
 ):
     model_type = infer_model_type(original_config, model_type=model_type)
@@ -1260,7 +1280,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
         try:
             text_encoder = create_text_encoder_from_open_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only, **config_kwargs
+                config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
             )
             tokenizer = CLIPTokenizer.from_pretrained(
                 config_name, subfolder="tokenizer", local_files_only=local_files_only
@@ -1279,6 +1299,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
                 config_name,
                 checkpoint,
                 local_files_only=local_files_only,
+                torch_dtype=torch_dtype,
             )
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
@@ -1302,6 +1323,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
                 prefix=prefix,
                 has_projection=True,
                 local_files_only=local_files_only,
+                torch_dtype=torch_dtype,
                 **config_kwargs,
             )
         except Exception:
@@ -1322,7 +1344,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
             config_name = "openai/clip-vit-large-patch14"
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
             text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only
+                config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
             )
         except Exception:
@@ -1341,6 +1363,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
                 prefix=prefix,
                 has_projection=True,
                 local_files_only=local_files_only,
+                torch_dtype=torch_dtype,
                 **config_kwargs,
             )
         except Exception:
...