Unverified commit 22ed39f5, authored by SahilCarterr, committed by GitHub

Added Lora Support to SD3 Img2Img Pipeline (#9659)

* add lora
parent 56c21150
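
In short, this commit makes StableDiffusion3Img2ImgPipeline inherit SD3LoraLoaderMixin (and FromSingleFileMixin), so LoRA checkpoints can be loaded and scaled the same way as in the SD3 text-to-image pipeline. A minimal sketch of the usage this enables; the LoRA repo and weight-file names below are illustrative placeholders, not from this commit:

    import torch
    from diffusers import StableDiffusion3Img2ImgPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")

    # load_lora_weights() comes from the newly inherited SD3LoraLoaderMixin.
    pipe.load_lora_weights("your-user/your-sd3-lora", weight_name="pytorch_lora_weights.safetensors")

    init_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
    )
    image = pipe(prompt="corgi", image=init_image).images[0]
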
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import PIL.Image
 import torch
@@ -25,7 +25,7 @@ from transformers import (
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -149,7 +149,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -680,6 +680,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -723,6 +727,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -835,6 +844,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -847,6 +857,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
         device = self._execution_device

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -868,6 +882,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )

         if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
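
The plumbing above mirrors the other SD3 pipelines: `__call__` reads an optional "scale" key out of joint_attention_kwargs, hands it to encode_prompt() as lora_scale so the text-encoder LoRA layers are scaled during prompt encoding, and forwards the whole dict to the transformer so its attention processors see it too. A minimal sketch, assuming the `pipe` and `init_image` from the example above:

    # "scale" < 1.0 attenuates the LoRA contribution in both the text encoders
    # and the transformer; any other keys are passed through to the attention
    # processors untouched.
    image = pipe(
        prompt="corgi",
        image=init_image,
        strength=0.8,
        joint_attention_kwargs={"scale": 0.5},
    ).images[0]
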
@@ -12,13 +12,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import sys
 import unittest

 import numpy as np
 import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

-from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
-from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    SD3Transformer2DModel,
+    StableDiffusion3Img2ImgPipeline,
+    StableDiffusion3Pipeline,
+)
+from diffusers.utils import load_image
+from diffusers.utils.import_utils import is_accelerate_available
+from diffusers.utils.testing_utils import (
+    is_peft_available,
+    numpy_cosine_similarity_distance,
+    require_peft_backend,
+    require_torch_gpu,
+    torch_device,
+)

 if is_peft_available():
@@ -29,6 +45,10 @@ sys.path.append(".")
 from utils import PeftLoraLoaderMixinTests  # noqa: E402

+if is_accelerate_available():
+    from accelerate.utils import release_memory
+

 @require_peft_backend
 class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
@@ -108,3 +128,88 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @unittest.skip("Not supported in SD3.")
     def test_modify_padding_mode(self):
         pass
+
+
+@require_torch_gpu
+@require_peft_backend
+class LoraSD3IntegrationTests(unittest.TestCase):
+    pipeline_class = StableDiffusion3Img2ImgPipeline
+    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, seed=0):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        return {
+            "prompt": "corgi",
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+            "generator": generator,
+            "image": init_image,
+        }
+
+    def test_sd3_img2img_lora(self):
+        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
+        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+        pipe.enable_sequential_cpu_offload()
+
+        inputs = self.get_inputs(torch_device)
+
+        image = pipe(**inputs).images[0]
+        image_slice = image[0, :10, :10]
+        expected_slice = np.array(
+            [
+                0.47827148,
+                0.5,
+                0.71972656,
+                0.3955078,
+                0.4194336,
+                0.69628906,
+                0.37036133,
+                0.40820312,
+                0.6923828,
+                0.36450195,
+                0.40429688,
+                0.6904297,
+                0.35595703,
+                0.39257812,
+                0.68652344,
+                0.35498047,
+                0.3984375,
+                0.68310547,
+                0.34716797,
+                0.3996582,
+                0.6855469,
+                0.3388672,
+                0.3959961,
+                0.6816406,
+                0.34033203,
+                0.40429688,
+                0.6845703,
+                0.34228516,
+                0.4086914,
+                0.6870117,
+            ]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+        assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
+
+        pipe.unload_lora_weights()
+        release_memory(pipe)
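
For reference, the numpy_cosine_similarity_distance check in the test is, to a close approximation, one minus the cosine similarity of the two flattened slices (a sketch, not the library source), so an identical output scores 0.0 and the assertion tolerates only tiny numerical drift:

    import numpy as np

    def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
        # 0.0 for identical arrays; grows toward 2.0 as the vectors diverge.
        a, b = a.flatten(), b.flatten()
        return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))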