Unverified Commit 5090b09d authored by Sayak Paul, committed by GitHub

[Flux LoRA] support parsing alpha from a flux lora state dict. (#9236)

* support parsing alpha from a flux lora state dict.

* conditional import.

* fix breaking changes.

* safeguard alpha.

* fix
parent 32d6492c
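
For orientation before the diff: a minimal usage sketch of what this change enables. The repo id and file name below come from the checkpoint referenced in the diff; the snippet is illustrative and is not part of the commit.

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Checkpoints such as TheLastBen/Jon_Snow_Flux_LoRA carry kohya-style `.alpha` entries next to
# the LoRA weights; with this change those entries are parsed and used when building the PEFT config.
pipe.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")

# The new `return_alphas` flag also exposes the parsed alphas directly:
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors", return_alphas=True
)
```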
@@ -1495,10 +1495,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
         **kwargs,
     ):
         r"""
@@ -1583,7 +1583,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )
-        return state_dict
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict

     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1617,7 +1636,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )

         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1625,6 +1646,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         self.load_lora_into_transformer(
             state_dict,
+            network_alphas=network_alphas,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1634,7 +1656,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             if len(text_encoder_state_dict) > 0:
                 self.load_lora_into_text_encoder(
                     text_encoder_state_dict,
-                    network_alphas=None,
+                    network_alphas=network_alphas,
                     text_encoder=self.text_encoder,
                     prefix="text_encoder",
                     lora_scale=self.lora_scale,
@@ -1643,8 +1665,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1653,6 +1674,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             transformer (`SD3Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
@@ -1684,7 +1709,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             if "lora_B" in key:
                 rank[key] = val.shape[1]

-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+        if network_alphas is not None and len(network_alphas) >= 1:
+            prefix = cls.transformer_name
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
...
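
Between the two files, it may help to spell out what the parsed alphas are used for: after the `transformer.` prefix is stripped (as in the hunk above), `get_peft_kwargs` turns them into per-module `lora_alpha` values, and the effective LoRA scale becomes alpha divided by rank. A toy illustration with made-up key names, not the library code path:

```python
# Toy illustration (assumed key names): how a parsed alpha relates to rank.
network_alphas = {"transformer.single_transformer_blocks.0.attn.to_k.alpha": 32.0}
rank = {"single_transformer_blocks.0.attn.to_k.lora_B.weight": 16}

prefix = "transformer"
# Mirror the prefix stripping done in load_lora_into_transformer above.
alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k.split(".")[0] == prefix}

alpha = alphas["single_transformer_blocks.0.attn.to_k.alpha"]
r = rank["single_transformer_blocks.0.attn.to_k.lora_B.weight"]
print(alpha / r)  # 2.0 -- the scaling applied on top of the LoRA update (lora_alpha / rank)
```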
@@ -12,19 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
+import tempfile
 import unittest

+import numpy as np
+import safetensors.torch
 import torch
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
+
+
+if is_peft_available():
+    from peft.utils import get_peft_model_state_dict
+

 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


 @require_peft_backend
@@ -90,3 +97,51 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pipeline_inputs.update({"generator": generator})

         return noise, input_ids, pipeline_inputs
+
+    def test_with_alpha_in_state_dict(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe.transformer.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # modify the state dict to have alpha values following
+            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
+            state_dict_with_alpha = safetensors.torch.load_file(
+                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
+            )
+            alpha_dict = {}
+            for k, v in state_dict_with_alpha.items():
+                # only do for `transformer` and for the k projections -- should be enough to test.
+                if "transformer" in k and "to_k" in k and "lora_A" in k:
+                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
+            state_dict_with_alpha.update(alpha_dict)
+
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(state_dict_with_alpha)
+        images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+        self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
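
As a debugging follow-up, a small, hypothetical helper (not part of the commit; `hf_hub_download` usage and the example repo/file are assumptions) can list the alpha entries a checkpoint carries before loading it:

```python
from huggingface_hub import hf_hub_download
import safetensors.torch


def list_alpha_keys(repo_id: str, filename: str) -> dict:
    """Return every kohya-style `.alpha` entry found in a LoRA safetensors file."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    state_dict = safetensors.torch.load_file(path)
    return {k: float(v) for k, v in state_dict.items() if k.endswith(".alpha")}


# e.g. list_alpha_keys("TheLastBen/Jon_Snow_Flux_LoRA", "jon_snow.safetensors")
```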