Unverified Commit d1db4f85 authored by Sayak Paul, committed by GitHub

[LoRA] fix flux lora loader when return_metadata is true for non-diffusers (#11716)

* fix flux lora loader when return_metadata is true for non-diffusers

* remove annotation
parent 8adc6003
@@ -2031,18 +2031,36 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
             # Kohya already takes care of scaling the LoRA parameters with alpha.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         is_xlabs = any("processor" in k for k in state_dict)
         if is_xlabs:
             state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
             # xlabs doesn't use `alpha`.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         is_bfl_control = any("query_norm.scale" in k for k in state_dict)
         if is_bfl_control:
             state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
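For context, the three non-diffusers branches above (Kohya, XLabs, BFL control) previously returned (state_dict, None) whenever return_alphas was set and dropped the metadata entirely. A minimal usage sketch of what the fix enables, assuming return_alphas and return_lora_metadata are accepted as keyword arguments of FluxLoraLoaderMixin.lora_state_dict (as the variable names in this diff suggest) and using a placeholder repo id:

from diffusers.loaders import FluxLoraLoaderMixin

# Placeholder id for a Kohya-trained Flux LoRA; any non-diffusers checkpoint
# (Kohya, XLabs, BFL control) goes through the branches patched above.
lora_id = "some-user/kohya-flux-lora"  # hypothetical, not a real repository

# Before this fix these branches ignored return_lora_metadata and returned
# (state_dict, None); now all three values come back in a fixed order.
state_dict, alphas, metadata = FluxLoraLoaderMixin.lora_state_dict(
    lora_id,
    return_alphas=True,
    return_lora_metadata=True,
)

print(len(state_dict))  # LoRA weights converted to the diffusers naming scheme
print(alphas)           # None here: Kohya folds alpha in, XLabs/BFL don't use it
print(metadata)         # adapter metadata if the file carries any, else None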
@@ -2061,12 +2079,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         )

         if return_alphas or return_lora_metadata:
-            outputs = [state_dict]
-            if return_alphas:
-                outputs.append(network_alphas)
-            if return_lora_metadata:
-                outputs.append(metadata)
-            return tuple(outputs)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=network_alphas,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
         else:
             return state_dict
@@ -2785,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")

+    @staticmethod
+    def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
+        outputs = [state_dict]
+        if return_alphas:
+            outputs.append(alphas)
+        if return_metadata:
+            outputs.append(metadata)
+        return tuple(outputs) if (return_alphas or return_metadata) else state_dict
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
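The new helper keeps the return shape backward compatible: callers that pass no flags still get the bare state dict, while flagged callers get a tuple whose elements always follow the order (state_dict, alphas, metadata). A standalone sketch of that contract, with throwaway dicts standing in for real LoRA weights:

def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
    # Same logic as the static method added above, reproduced here so the
    # snippet runs on its own.
    outputs = [state_dict]
    if return_alphas:
        outputs.append(alphas)
    if return_metadata:
        outputs.append(metadata)
    return tuple(outputs) if (return_alphas or return_metadata) else state_dict


sd = {"transformer.x.lora_A.weight": 0}   # stand-in for a real tensor dict
meta = {"r": 16, "lora_alpha": 16}

assert _prepare_outputs(sd, meta) is sd                                   # no flags -> bare dict
assert _prepare_outputs(sd, meta, return_metadata=True) == (sd, meta)     # metadata only
assert _prepare_outputs(sd, meta, return_alphas=True, return_metadata=True) == (sd, None, meta)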
@@ -187,7 +187,9 @@ class PeftAdapterMixin:
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-            metadata: TODO
+            metadata:
+                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
+                initialize `LoraConfig`.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
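On the PEFT side, the clarified docstring says that when metadata is supplied it takes precedence over whatever would be inferred from the state dict when building the LoraConfig. A hedged sketch of that idea, assuming the metadata dict holds plain LoraConfig keyword arguments (the exact wiring inside load_lora_adapter is not shown in this diff):

from peft import LoraConfig

# Illustrative adapter metadata as it might have been serialized at save time;
# the keys here are assumptions, not necessarily what diffusers stores.
metadata = {
    "r": 16,
    "lora_alpha": 16,
    "target_modules": ["to_q", "to_k", "to_v", "to_out.0"],
}

# With metadata available, the config comes straight from it rather than being
# reverse-engineered from tensor shapes and key names in the state dict.
lora_config = LoraConfig(**metadata)
print(lora_config.r, lora_config.lora_alpha, lora_config.target_modules)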
@@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str):
         metadata = f.metadata() or {}
         metadata.pop("format", None)

-    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
-    return json.loads(raw) if raw else None
+    if metadata:
+        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+        return json.loads(raw) if raw else None
+    else:
+        return None
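_load_sft_state_dict_metadata reads the adapter config back out of the safetensors header, and the new if metadata: guard makes the no-metadata case return None explicitly instead of relying on lookups against an empty dict. A small round-trip sketch of the same pattern using the public safetensors API; the key name below is a placeholder for diffusers' LORA_ADAPTER_METADATA_KEY constant:

import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

METADATA_KEY = "lora_adapter_metadata"  # placeholder for LORA_ADAPTER_METADATA_KEY

# safetensors header metadata is a flat str -> str mapping, so the adapter
# config is stored as a JSON string under a single key.
tensors = {"lora_A.weight": torch.zeros(4, 4)}
save_file(tensors, "lora.safetensors", metadata={METADATA_KEY: json.dumps({"r": 16, "lora_alpha": 16})})

# Reading mirrors the patched helper: metadata() may be None or empty (hence
# the `or {}` and the new `if metadata:` guard), and the bookkeeping "format"
# entry is dropped before decoding the JSON payload.
with safe_open("lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}
metadata.pop("format", None)

if metadata:
    raw = metadata.get(METADATA_KEY)
    print(json.loads(raw) if raw else None)  # -> {'r': 16, 'lora_alpha': 16}
else:
    print(None)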