Unverified Commit fc6a91e3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[FLUX] support LoRA (#9057)

* feat: lora support for Flux.

add tests

fix imports

major fixes.

* fix

fixes

final fixes?

* fix

* remove is_peft_available.
parent 2b760996
......@@ -66,6 +66,7 @@ if is_torch_available():
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
......@@ -83,6 +84,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
FluxLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
......
This diff is collapsed.
......@@ -32,6 +32,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
}
......
......@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
......@@ -137,7 +137,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
......@@ -321,7 +321,7 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
......@@ -354,12 +354,12 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
)
if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
transformer_cls = FluxTransformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 1,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": 0.0609,
"scaling_factor": 1.5035,
}
has_two_text_encoders = True
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
......@@ -22,6 +22,7 @@ import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoPipelineForImage2Image,
......@@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
@property
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
......
......@@ -15,10 +15,9 @@
import sys
import unittest
from diffusers import (
FlowMatchEulerDiscreteScheduler,
StableDiffusion3Pipeline,
)
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
......@@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
......@@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"pooled_projection_dim": 64,
"out_channels": 4,
}
transformer_cls = SD3Transformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
......@@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"scaling_factor": 1.5035,
}
has_three_text_encoders = True
tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder"
text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2"
text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
@property
def output_shape(self):
return (1, 32, 32, 3)
@require_torch_gpu
def test_sd3_lora(self):
......
......@@ -22,6 +22,7 @@ import unittest
import numpy as np
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
ControlNetModel,
......@@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4,
"sample_size": 128,
}
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2"
tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
@property
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment