Unverified Commit 4343ce2c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] Harmonize single file ckpt model loading (#6971)

* use load_model_into_meta in single file utils

* propagate to autoencoder and controlnet.

* correct class name access behaviour.

* remove torch_dtype from load_model_into_meta; seems unnecessary

* remove incorrect kwarg

* style to avoid extra unnecessary line breaks
parent 0ca7b681
...@@ -48,6 +48,7 @@ def build_sub_model_components( ...@@ -48,6 +48,7 @@ def build_sub_model_components(
load_safety_checker=False, load_safety_checker=False,
model_type=None, model_type=None,
image_size=None, image_size=None,
torch_dtype=None,
**kwargs, **kwargs,
): ):
if component_name in pipeline_components: if component_name in pipeline_components:
...@@ -96,7 +97,7 @@ def build_sub_model_components( ...@@ -96,7 +97,7 @@ def build_sub_model_components(
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
safety_checker = StableDiffusionSafetyChecker.from_pretrained( safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
) )
else: else:
safety_checker = None safety_checker = None
......
...@@ -48,7 +48,6 @@ if is_transformers_available(): ...@@ -48,7 +48,6 @@ if is_transformers_available():
if is_accelerate_available(): if is_accelerate_available():
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -874,8 +873,17 @@ def create_diffusers_controlnet_model_from_ldm( ...@@ -874,8 +873,17 @@ def create_diffusers_controlnet_model_from_ldm(
controlnet = ControlNetModel(**diffusers_config) controlnet = ControlNetModel(**diffusers_config)
if is_accelerate_available(): if is_accelerate_available():
for param_name, param in diffusers_format_controlnet_checkpoint.items(): from ..models.modeling_utils import load_model_dict_into_meta
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
unexpected_keys = load_model_dict_into_meta(controlnet, diffusers_format_controlnet_checkpoint)
if controlnet._keys_to_ignore_on_load_unexpected is not None:
for pat in controlnet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
controlnet.load_state_dict(diffusers_format_controlnet_checkpoint) controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
...@@ -1038,8 +1046,17 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_ ...@@ -1038,8 +1046,17 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
text_model_dict[diffusers_key] = checkpoint[key] text_model_dict[diffusers_key] = checkpoint[key]
if is_accelerate_available(): if is_accelerate_available():
for param_name, param in text_model_dict.items(): from ..models.modeling_utils import load_model_dict_into_meta
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
text_model_dict.pop("text_model.embeddings.position_ids", None) text_model_dict.pop("text_model.embeddings.position_ids", None)
...@@ -1120,8 +1137,17 @@ def create_text_encoder_from_open_clip_checkpoint( ...@@ -1120,8 +1137,17 @@ def create_text_encoder_from_open_clip_checkpoint(
text_model_dict[diffusers_key] = checkpoint[key] text_model_dict[diffusers_key] = checkpoint[key]
if is_accelerate_available(): if is_accelerate_available():
for param_name, param in text_model_dict.items(): from ..models.modeling_utils import load_model_dict_into_meta
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
...@@ -1168,8 +1194,17 @@ def create_diffusers_unet_model_from_ldm( ...@@ -1168,8 +1194,17 @@ def create_diffusers_unet_model_from_ldm(
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config)
if is_accelerate_available(): if is_accelerate_available():
for param_name, param in diffusers_format_unet_checkpoint.items(): from ..models.modeling_utils import load_model_dict_into_meta
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint)
if unet._keys_to_ignore_on_load_unexpected is not None:
for pat in unet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
unet.load_state_dict(diffusers_format_unet_checkpoint) unet.load_state_dict(diffusers_format_unet_checkpoint)
...@@ -1192,8 +1227,17 @@ def create_diffusers_vae_model_from_ldm( ...@@ -1192,8 +1227,17 @@ def create_diffusers_vae_model_from_ldm(
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
if is_accelerate_available(): if is_accelerate_available():
for param_name, param in diffusers_format_vae_checkpoint.items(): from ..models.modeling_utils import load_model_dict_into_meta
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint)
if vae._keys_to_ignore_on_load_unexpected is not None:
for pat in vae._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: else:
vae.load_state_dict(diffusers_format_vae_checkpoint) vae.load_state_dict(diffusers_format_vae_checkpoint)
...@@ -1230,7 +1274,9 @@ def create_text_encoders_and_tokenizers_from_ldm( ...@@ -1230,7 +1274,9 @@ def create_text_encoders_and_tokenizers_from_ldm(
try: try:
config_name = "openai/clip-vit-large-patch14" config_name = "openai/clip-vit-large-patch14"
text_encoder = create_text_encoder_from_ldm_clip_checkpoint( text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
config_name, checkpoint, local_files_only=local_files_only config_name,
checkpoint,
local_files_only=local_files_only,
) )
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment