Unverified Commit f69511ec authored by Dhruv Nair, committed by GitHub

[Single File Loading] Handle unexpected keys in CLIP models when `accelerate` isn't installed. (#8462)

* update

* update

* update

* update

* update

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d2b10b1f
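The core of the change: `torch.nn.Module.load_state_dict(strict=False)` returns a `(missing_keys, unexpected_keys)` named tuple, so the non-accelerate path can collect and filter unexpected keys the same way the accelerate path does. A minimal sketch of the pattern, using a toy `Linear` model and an illustrative ignore pattern rather than the actual diffusers classes:

import re

import torch

# Toy stand-in for the CLIP text encoder; the extra key mimics the stray
# "position_ids" entry that older conversion code injected.
model = torch.nn.Linear(4, 4)
state_dict = dict(model.state_dict())
state_dict["text_model.embeddings.position_ids"] = torch.arange(77)

# strict=False reports what didn't fit instead of raising a RuntimeError;
# the return value is a named tuple (missing_keys, unexpected_keys).
_, unexpected_keys = model.load_state_dict(state_dict, strict=False)

# Mirrors the filtering against model._keys_to_ignore_on_load_unexpected.
for pat in ["position_ids"]:  # illustrative pattern
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

print(unexpected_keys)  # [] -- the stray key is ignored rather than warned about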
@@ -276,6 +276,10 @@ class FromOriginalModelMixin:
         if is_accelerate_available():
             unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        else:
+            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
         if model._keys_to_ignore_on_load_unexpected is not None:
             for pat in model._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -284,8 +288,6 @@ class FromOriginalModelMixin:
             logger.warning(
                 f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
             )
-        else:
-            model.load_state_dict(diffusers_format_checkpoint)
 
         if torch_dtype is not None:
             model.to(torch_dtype)
@@ -1268,8 +1268,6 @@ def convert_open_clip_checkpoint(
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
 
-    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
     keys = list(checkpoint.keys())
     keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
@@ -1318,9 +1316,6 @@ def convert_open_clip_checkpoint(
         else:
             text_model_dict[diffusers_key] = checkpoint.get(key)
 
-    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
-        text_model_dict.pop("text_model.embeddings.position_ids", None)
-
     return text_model_dict
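The deleted `position_ids` handling tracks a change in `transformers`: recent versions (v4.31+, to the best of my knowledge) register CLIP's `position_ids` as a non-persistent buffer, so it never belongs in a converted state dict and would only resurface as an unexpected key at load time. A toy illustration of the buffer behavior relied on here (not the transformers implementation):

import torch

class ToyEmbeddings(torch.nn.Module):
    def __init__(self, max_positions=77):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict()
        self.register_buffer(
            "position_ids", torch.arange(max_positions).unsqueeze(0), persistent=False
        )

print("position_ids" in ToyEmbeddings().state_dict())  # False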
@@ -1414,6 +1409,9 @@ def create_diffusers_clip_model_from_ldm(
     if is_accelerate_available():
         unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+    else:
+        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
     if model._keys_to_ignore_on_load_unexpected is not None:
         for pat in model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
@@ -1423,9 +1421,6 @@ def create_diffusers_clip_model_from_ldm(
             f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
         )
-    else:
-        model.load_state_dict(diffusers_format_checkpoint)
-
     if torch_dtype is not None:
         model.to(torch_dtype)
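A quick way to exercise the patched branch: load a single-file SD 2.x checkpoint in an environment where `accelerate` is not installed. Before this change the strict `load_state_dict` call raised a `RuntimeError` on the unexpected CLIP keys; now they are filtered and, if anything remains, reported via the warning in both code paths. The checkpoint URL below is illustrative:

from diffusers import StableDiffusionPipeline

# Runs the non-accelerate `else` branch added above when `accelerate` is missing.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
)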