Unverified Commit 52222947 authored by Sayak Paul, committed by GitHub

[LoRA] relax lora loading logic (#4610)



* relax lora loading logic.

* cater to the other cases too.

* fix: variable name

* bring the chaos down.

* check

* deal with checkpointed files.

* Apply suggestions from code review
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>

* style

---------
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
parent c25c4613
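What the change buys in practice, as a minimal sketch: once `weight_name` can be omitted, the single weights file in a LoRA repo is picked up automatically. The repo id below is hypothetical; `StableDiffusionPipeline` and `load_lora_weights` are existing diffusers APIs.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Before this patch, omitting `weight_name` fell back to the default LoRA file name
# (see `LORA_WEIGHT_NAME_SAFE` in the diff below) and failed for repos that ship a
# single, differently named weights file. Now the file name is guessed from the repo contents.
pipe.load_lora_weights("some-user/some-lora")  # hypothetical repo id; no weight_name needed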
@@ -25,7 +25,7 @@ import requests
 import safetensors
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 from .utils import (
@@ -1021,6 +1021,13 @@ class LoraLoaderMixin:
             weight_name is not None and weight_name.endswith(".safetensors")
         ):
             try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = cls._best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+                    )
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@ class LoraLoaderMixin:
                 # try loading non-safetensors weights
                 model_file = None
                 pass
         if model_file is None:
+            if weight_name is None:
+                weight_name = cls._best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin"
+                )
             model_file = _get_model_file(
                 pretrained_model_name_or_path_or_dict,
                 weights_name=weight_name or LORA_WEIGHT_NAME,
@@ -1077,6 +1089,31 @@ class LoraLoaderMixin:
         return state_dict, network_alphas

+    @classmethod
+    def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+        targeted_files = []
+
+        if os.path.isfile(pretrained_model_name_or_path_or_dict):
+            return
+        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+            targeted_files = [
+                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+            ]
+        else:
+            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+
+        if len(targeted_files) == 0:
+            return
+
+        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+
+        if len(targeted_files) > 1:
+            raise ValueError(
+                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+            )
+
+        weight_name = targeted_files[0]
+        return weight_name
+
     @classmethod
     def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
         is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
...
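The Hub branch of the new helper leans only on `huggingface_hub.model_info`, which lists a repo's files via `.siblings` (each entry exposing its path as `.rfilename`). A standalone sketch of the same guessing logic, with a hypothetical helper name and repo id:

from huggingface_hub import model_info

def guess_lora_weight_name(repo_id, file_extension=".safetensors"):
    # List every file in the Hub repo and keep those with the wanted extension.
    files = [s.rfilename for s in model_info(repo_id).siblings if s.rfilename.endswith(file_extension)]
    # Drop training-state artifacts (scheduler/optimizer checkpoints) that can share the extension.
    files = [f for f in files if "scheduler" not in f and "optimizer" not in f]
    if not files:
        return None
    if len(files) > 1:
        raise ValueError(f"More than one {file_extension} file found; pass `weight_name` explicitly.")
    return files[0]

# e.g. guess_lora_weight_name("some-user/some-lora")  # hypothetical repo; returns its single .safetensors file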