Unverified Commit 8f2253c5 authored by hlky, committed by GitHub

Add torch_xla and from_single_file to instruct-pix2pix (#10444)



* Add torch_xla and from_single_file to instruct-pix2pix

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 7747b588
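
In practice the new loader path looks like this — a minimal sketch, assuming a CUDA device and fp16 weights (the checkpoint URL is the one exercised by the new slow test below):

# Sketch only: load the original single-file InstructPix2Pix checkpoint directly.
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

pipe = StableDiffusionInstructPix2PixPipeline.from_single_file(
    "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")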
src/diffusers/loaders/single_file_utils.py
@@ -109,6 +109,7 @@ CHECKPOINT_KEY_NAMES = {
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -165,6 +166,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -633,6 +635,12 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"
 
+    elif (
+        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+    ):
+        model_type = "instruct-pix2pix"
+
     else:
         model_type = "v1"
...
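
The detection keys off the UNet's conv_in weight: an InstructPix2Pix UNet takes 8 input channels (4 noise-latent channels plus 4 image-conditioning channels concatenated on), versus 4 for a plain SD v1 checkpoint. In isolation, the new branch amounts to this sketch, where `checkpoint` is assumed to be an already-loaded state dict:

# Assumes `checkpoint` is a state dict loaded from the .safetensors file.
key = "model.diffusion_model.input_blocks.0.0.weight"  # the UNet's conv_in weight
if key in checkpoint and checkpoint[key].shape[1] == 8:
    # 4 noise-latent channels + 4 image-conditioning channels
    model_type = "instruct-pix2pix"
else:
    model_type = "v1"  # a plain SD v1 conv_in takes only 4 input channels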
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -22,16 +22,23 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
     TextualInversionLoaderMixin,
     StableDiffusionLoraLoaderMixin,
     IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -457,6 +465,9 @@ class StableDiffusionInstructPix2PixPipeline(
                             step_idx = i // getattr(self.scheduler, "order", 1)
                             callback(step_idx, t, latents)
 
+                    if XLA_AVAILABLE:
+                        xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
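
torch_xla traces tensor ops lazily, so the per-iteration xm.mark_step() added above cuts the pending graph and forces each denoising step to execute on the XLA device instead of accumulating the whole loop into one giant graph. The guarded pattern, sketched standalone with a placeholder loop (assumes torch_xla is installed when running on TPU):

from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

num_inference_steps = 3  # placeholder loop standing in for the real denoising loop
for i in range(num_inference_steps):
    # ... noise prediction and scheduler step happen here in the pipeline ...
    if XLA_AVAILABLE:
        xm.mark_step()  # materialize this step's lazily traced ops on the XLA device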
tests/single_file/test_stable_diffusion_single_file.py
@@ -4,11 +4,13 @@ import unittest
 
 import torch
 
-from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
+from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
 from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
+    nightly,
     require_torch_accelerator,
     slow,
     torch_device,
@@ -118,3 +120,44 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
     def test_single_file_format_inference_is_same_as_pretrained(self):
         super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
+
+
+@nightly
+@slow
+@require_torch_accelerator
+class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+    pipeline_class = StableDiffusionInstructPix2PixPipeline
+    ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
+    original_config = (
+        "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
+    )
+    repo_id = "timbrooks/instruct-pix2pix"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=generator_device).manual_seed(seed)
+        image = load_image(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
+        )
+        inputs = {
+            "prompt": "turn him into a cyborg",
+            "image": image,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 7.5,
+            "image_guidance_scale": 1.0,
+            "output_type": "np",
+        }
+        return inputs
+
+    def test_single_file_format_inference_is_same_as_pretrained(self):
+        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
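
Roughly, the new test asserts that the single-file load and the hub-format load of the same weights produce near-identical images. A standalone sketch of that comparison (CUDA device assumed; the 1e-3 bound mirrors expected_max_diff above):

# Sketch of what the inherited equivalence test checks, outside the test harness.
import numpy as np
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
)
inputs = dict(
    prompt="turn him into a cyborg",
    image=image,
    num_inference_steps=3,
    guidance_scale=7.5,
    image_guidance_scale=1.0,
    output_type="np",
)

pipe_single = StableDiffusionInstructPix2PixPipeline.from_single_file(ckpt_path).to("cuda")
out_single = pipe_single(generator=torch.Generator("cpu").manual_seed(0), **inputs).images[0]

pipe_hub = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix").to("cuda")
out_hub = pipe_hub(generator=torch.Generator("cpu").manual_seed(0), **inputs).images[0]

assert np.abs(out_single - out_hub).max() < 1e-3  # same tolerance as expected_max_diff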