Unverified Commit 0de35e4a authored by Sayak Paul, committed by GitHub

[Tests] Tighten up LoRA loading relaxation (#4787)



* debugging

* better logic for filtering.

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0d81e543
@@ -1102,11 +1102,17 @@ class LoraLoaderMixin:
         else:
             files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
             targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
         if len(targeted_files) == 0:
             return
-        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+
+        # "scheduler" does not correspond to a LoRA checkpoint.
+        # "optimizer" does not correspond to a LoRA checkpoint.
+        # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+        unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+        targeted_files = list(
+            filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+        )
         if len(targeted_files) > 1:
             raise ValueError(
                 f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
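For context, the sketch below replays the new filtering logic outside of `loaders.py` on a hypothetical list of repository filenames (the filenames and the simplified error message are illustrative, not taken from the library or any real Hub repo). It shows why "checkpoint" is now excluded alongside "scheduler" and "optimizer": files nested under training checkpoint directories no longer shadow the single top-level LoRA weights file.

```python
# Minimal, standalone sketch of the filtering introduced in this commit.
file_extension = ".safetensors"

# Hypothetical filenames such as model_info(...).siblings might surface.
files_in_repo = [
    "pytorch_lora_weights.safetensors",
    "checkpoint-500/pytorch_lora_weights.safetensors",
    "optimizer.safetensors",
    "scheduler/scheduler_config.json",
]

# Keep only files with the requested extension.
targeted_files = [f for f in files_in_repo if f.endswith(file_extension)]

# Same substring filter as in the diff: drop scheduler/optimizer files and
# anything containing "checkpoint", so only the top-level LoRA file remains.
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = [f for f in targeted_files if all(s not in f for s in unallowed_substrings)]

print(targeted_files)  # ['pytorch_lora_weights.safetensors']

if len(targeted_files) > 1:
    raise ValueError("More than one weights file found; pass `weight_name` explicitly.")
```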