Unverified Commit fc6a91e3 authored by Sayak Paul, committed by GitHub

[FLUX] support LoRA (#9057)

* feat: lora support for Flux.

add tests

fix imports

major fixes.

* fix

fixes

final fixes?

* fix

* remove is_peft_available.
parent 2b760996
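What this enables, in practice: FluxPipeline now mixes in FluxLoraLoaderMixin, so LoRA checkpoints can be loaded the same way as for the other PEFT-backed pipelines. A minimal usage sketch; the LoRA repo id and weight file name below are hypothetical placeholders, and the base checkpoint id is assumed for illustration:

import torch
from diffusers import FluxPipeline

# Base pipeline; "black-forest-labs/FLUX.1-dev" is an assumed example checkpoint.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# load_lora_weights() comes from the new FluxLoraLoaderMixin; the adapter repo/file here is hypothetical.
pipe.load_lora_weights("some-user/some-flux-lora", weight_name="pytorch_lora_weights.safetensors")

pipe.to("cuda")
image = pipe("A painting of a squirrel eating a burger").images[0]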
@@ -66,6 +66,7 @@ if is_torch_available():
         "SD3LoraLoaderMixin",
         "StableDiffusionXLLoraLoaderMixin",
         "LoraLoaderMixin",
+        "FluxLoraLoaderMixin",
     ]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -83,6 +84,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .ip_adapter import IPAdapterMixin
     from .lora_pipeline import (
         AmusedLoraLoaderMixin,
+        FluxLoraLoaderMixin,
         LoraLoaderMixin,
         SD3LoraLoaderMixin,
         StableDiffusionLoraLoaderMixin,
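With both changes, FluxLoraLoaderMixin becomes part of the public diffusers.loaders namespace (lazily via _import_structure, and eagerly under TYPE_CHECKING). A quick sanity-check sketch, assuming the rest of this PR is applied:

from diffusers import FluxPipeline
from diffusers.loaders import FluxLoraLoaderMixin

# FluxPipeline gains its LoRA API by subclassing the new mixin (see the pipeline diff below).
assert issubclass(FluxPipeline, FluxLoraLoaderMixin)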
(collapsed diff not shown)
@@ -32,6 +32,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
     "UNet2DConditionModel": _maybe_expand_lora_scales,
     "UNetMotionModel": _maybe_expand_lora_scales,
     "SD3Transformer2DModel": lambda model_cls, weights: weights,
+    "FluxTransformer2DModel": lambda model_cls, weights: weights,
 }
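For context, this mapping pairs a model class name with a callable that normalizes per-adapter weights before they are applied: the UNets expand per-block scale dictionaries via _maybe_expand_lora_scales, while the DiT-style transformers (SD3 and now Flux) pass weights through unchanged, hence the identity lambda. An illustrative dispatch sketch, not the library's literal code:

# Trimmed-down mapping for illustration; the real module also maps the UNet classes.
_SET_ADAPTER_SCALE_FN_MAPPING = {
    "SD3Transformer2DModel": lambda model_cls, weights: weights,
    "FluxTransformer2DModel": lambda model_cls, weights: weights,  # identity: no per-block expansion
}

def normalize_adapter_weights(model, weights):
    # Look up the per-architecture normalizer by class name and apply it.
    scale_fn = _SET_ADAPTER_SCALE_FN_MAPPING[model.__class__.__name__]
    return scale_fn(model, weights)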
@@ -20,7 +20,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from ...image_processor import VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -137,7 +137,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
+class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     r"""
     The Flux pipeline for text-to-image generation.
@@ -321,7 +321,7 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
@@ -354,12 +354,12 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
             )

         if self.text_encoder is not None:
-            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)

         if self.text_encoder_2 is not None:
-            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
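The lora_scale plumbing here mirrors the other PEFT-backed pipelines: store the scale where the monkey-patched layers can read it, scale the text-encoder LoRA layers up before encoding the prompt, then scale them back afterwards so later calls see the original weights. A condensed standalone sketch of that pattern; the helper names are the real diffusers.utils ones, but the surrounding flow is paraphrased, not the pipeline's literal code:

from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_prompt_with_lora_scale(pipe, encode_fn, lora_scale=None):
    # Stash the scale so monkey-patched LoRA layers can access it, then scale up.
    if lora_scale is not None and isinstance(pipe, FluxLoraLoaderMixin):
        pipe._lora_scale = lora_scale
        if USE_PEFT_BACKEND:
            scale_lora_layers(pipe.text_encoder, lora_scale)
            scale_lora_layers(pipe.text_encoder_2, lora_scale)

    embeds = encode_fn()  # run the CLIP and T5 encoders

    # Retrieve the original scale by scaling the LoRA layers back down.
    if pipe.text_encoder is not None and isinstance(pipe, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        unscale_lora_layers(pipe.text_encoder, lora_scale)
    if pipe.text_encoder_2 is not None and isinstance(pipe, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        unscale_lora_layers(pipe.text_encoder_2, lora_scale)
    return embeds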
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend


sys.path.append(".")

from utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    pipeline_class = FluxPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler()
    scheduler_kwargs = {}
    uses_flow_matching = True
    transformer_kwargs = {
        "patch_size": 1,
        "in_channels": 4,
        "num_layers": 1,
        "num_single_layers": 1,
        "attention_head_dim": 16,
        "num_attention_heads": 2,
        "joint_attention_dim": 32,
        "pooled_projection_dim": 32,
        "axes_dims_rope": [4, 4, 8],
    }
    transformer_cls = FluxTransformer2DModel
    vae_kwargs = {
        "sample_size": 32,
        "in_channels": 3,
        "out_channels": 3,
        "block_out_channels": (4,),
        "layers_per_block": 1,
        "latent_channels": 1,
        "norm_num_groups": 1,
        "use_quant_conv": False,
        "use_post_quant_conv": False,
        "shift_factor": 0.0609,
        "scaling_factor": 1.5035,
    }
    has_two_text_encoders = True
    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"

    @property
    def output_shape(self):
        return (1, 8, 8, 3)

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 4,
            "guidance_scale": 0.0,
            "height": 8,
            "width": 8,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs
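Why output_shape is (1, 8, 8, 3): the dummy call requests height=8 and width=8 with output_type="np", so the decoded result is a single HWC numpy frame. The tiny transformer config is internally consistent too: inner_dim is 2 heads x 16 = 32, and axes_dims_rope (4 + 4 + 8) sums to the attention head dim of 16. As a sanity-check sketch, it can be instantiated on its own (CPU-sized, nothing downloaded):

from diffusers import FluxTransformer2DModel

# Instantiate the tiny test config from the class above.
tiny_transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=[4, 4, 8],
)
print(sum(p.numel() for p in tiny_transformer.parameters()))  # small enough for CI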
@@ -22,6 +22,7 @@ import torch.nn as nn
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer

 from diffusers import (
     AutoPipelineForImage2Image,
@@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
         "latent_channels": 4,
     }
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+    @property
+    def output_shape(self):
+        return (1, 64, 64, 3)

     def setUp(self):
         super().setUp()
@@ -15,10 +15,9 @@
 import sys
 import unittest

-from diffusers import (
-    FlowMatchEulerDiscreteScheduler,
-    StableDiffusion3Pipeline,
-)
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
+from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
 from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
@@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}
+    uses_flow_matching = True
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
@@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         "pooled_projection_dim": 64,
         "out_channels": 4,
     }
+    transformer_cls = SD3Transformer2DModel
     vae_kwargs = {
         "sample_size": 32,
         "in_channels": 3,
@@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         "scaling_factor": 1.5035,
     }
     has_three_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+    tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+    tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder"
+    text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2"
+    text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+    @property
+    def output_shape(self):
+        return (1, 32, 32, 3)

     @require_torch_gpu
     def test_sd3_lora(self):
@@ -22,6 +22,7 @@ import unittest
 import numpy as np
 import torch
 from packaging import version
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import (
     ControlNetModel,
@@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         "latent_channels": 4,
         "sample_size": 128,
     }
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+    @property
+    def output_shape(self):
+        return (1, 64, 64, 3)

     def setUp(self):
         super().setUp()
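All of these test classes feed the same shared PeftLoraLoaderMixinTests harness (the utils module imported by the new Flux test), which builds a tiny pipeline from these class-level attributes. A hypothetical sketch of how that assembly could look for the transformer-based classes (Flux, SD3); the harness's real helper may differ in detail:

# Hypothetical illustration, not the harness's actual code.
def build_components(cls):
    components = {
        "transformer": cls.transformer_cls(**cls.transformer_kwargs),
        "scheduler": cls.scheduler_cls,  # already an instance for the flow-matching tests
        "text_encoder": cls.text_encoder_cls.from_pretrained(cls.text_encoder_id),
        "tokenizer": cls.tokenizer_cls.from_pretrained(cls.tokenizer_id),
    }
    if getattr(cls, "has_two_text_encoders", False):
        components["text_encoder_2"] = cls.text_encoder_2_cls.from_pretrained(cls.text_encoder_2_id)
        components["tokenizer_2"] = cls.tokenizer_2_cls.from_pretrained(cls.tokenizer_2_id)
    return components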
(collapsed diff not shown)