Unverified commit 22ed39f5, authored by SahilCarterr, committed by GitHub
Browse files

Added Lora Support to SD3 Img2Img Pipeline (#9659)

* add lora
parent 56c21150
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -25,7 +25,7 @@ from transformers import ( ...@@ -25,7 +25,7 @@ from transformers import (
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
...@@ -149,7 +149,7 @@ def retrieve_timesteps( ...@@ -149,7 +149,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps return timesteps, num_inference_steps
class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
r""" r"""
Args: Args:
transformer ([`SD3Transformer2DModel`]): transformer ([`SD3Transformer2DModel`]):
...@@ -680,6 +680,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -680,6 +680,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
@property
def joint_attention_kwargs(self):
    """Return the value stored in `self._joint_attention_kwargs`."""
    return self._joint_attention_kwargs
@property @property
def clip_skip(self): def clip_skip(self):
return self._clip_skip return self._clip_skip
...@@ -723,6 +727,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -723,6 +727,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
...@@ -797,6 +802,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -797,6 +802,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple. of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*): callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
...@@ -835,6 +844,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -835,6 +844,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
...@@ -847,6 +857,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -847,6 +857,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
device = self._execution_device device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
( (
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
...@@ -868,6 +882,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -868,6 +882,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length, max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
) )
if self.do_classifier_free_guidance: if self.do_classifier_free_guidance:
...@@ -912,6 +927,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -912,6 +927,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
timestep=timestep, timestep=timestep,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds, pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
......
...@@ -12,13 +12,29 @@ ...@@ -12,13 +12,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import sys import sys
import unittest import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers import (
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device FlowMatchEulerDiscreteScheduler,
SD3Transformer2DModel,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
is_peft_available,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
torch_device,
)
if is_peft_available(): if is_peft_available():
...@@ -29,6 +45,10 @@ sys.path.append(".") ...@@ -29,6 +45,10 @@ sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402 from utils import PeftLoraLoaderMixinTests # noqa: E402
if is_accelerate_available():
from accelerate.utils import release_memory
@require_peft_backend @require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
...@@ -108,3 +128,88 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -108,3 +128,88 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in SD3.") @unittest.skip("Not supported in SD3.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@require_torch_gpu
@require_peft_backend
class LoraSD3IntegrationTests(unittest.TestCase):
    """Integration test that runs the SD3 img2img pipeline with LoRA weights loaded."""

    pipeline_class = StableDiffusion3Img2ImgPipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    def setUp(self):
        """Run garbage collection and empty the CUDA cache before each test."""
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        """Run garbage collection and empty the CUDA cache after each test."""
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        """Build the keyword arguments passed to the pipeline call.

        Args:
            device: Only inspected to detect an "mps" device, which selects how the
                generator is seeded.
            seed: Seed for the random generator.

        Returns:
            A dict with the prompt, step count, guidance scale, output type,
            generator and initial image.
        """
        image_url = (
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )
        init_image = load_image(image_url)

        # On "mps" the global seed is set; otherwise a dedicated CPU generator is seeded.
        on_mps = str(device).startswith("mps")
        generator = (
            torch.manual_seed(seed) if on_mps else torch.Generator(device="cpu").manual_seed(seed)
        )

        return dict(
            prompt="corgi",
            num_inference_steps=2,
            guidance_scale=5.0,
            output_type="np",
            generator=generator,
            image=init_image,
        )

    def test_sd3_img2img_lora(self):
        """Run the pipeline with LoRA weights and compare an output slice to reference values."""
        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
        pipe.enable_sequential_cpu_offload()

        output = pipe(**self.get_inputs(torch_device))
        # First index 0, then at most 10 entries along each of the next two axes.
        # NOTE(review): presumably this yields 10 pixels x 3 channels = 30 values,
        # matching the 30 reference values below - confirm against the output shape.
        image_slice = output.images[0][0, :10, :10]

        # fmt: off
        expected_slice = np.array(
            [
                0.47827148, 0.5, 0.71972656,
                0.3955078, 0.4194336, 0.69628906,
                0.37036133, 0.40820312, 0.6923828,
                0.36450195, 0.40429688, 0.6904297,
                0.35595703, 0.39257812, 0.68652344,
                0.35498047, 0.3984375, 0.68310547,
                0.34716797, 0.3996582, 0.6855469,
                0.3388672, 0.3959961, 0.6816406,
                0.34033203, 0.40429688, 0.6845703,
                0.34228516, 0.4086914, 0.6870117,
            ]
        )
        # fmt: on

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
        assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"

        pipe.unload_lora_weights()
        release_memory(pipe)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment