Unverified Commit 5bacc2f5 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SAG] Support more schedulers, add better error message and make tests faster (#6465)



* finish

* finish

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 6ae7e811
...@@ -681,6 +681,11 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin, ...@@ -681,6 +681,11 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin,
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]:
raise ValueError(
f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinlgestepScheduler'."
)
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
...@@ -830,14 +835,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin, ...@@ -830,14 +835,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin,
degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
# Noise it again to match the noise level # Noise it again to match the noise level
degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t[None])
return degraded_latents return degraded_latents
# Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
# Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
def pred_x0(self, sample, model_output, timestep): def pred_x0(self, sample, model_output, timestep):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
if self.scheduler.config.prediction_type == "epsilon": if self.scheduler.config.prediction_type == "epsilon":
......
...@@ -23,6 +23,9 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -23,6 +23,9 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
StableDiffusionSAGPipeline, StableDiffusionSAGPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
...@@ -45,14 +48,15 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes ...@@ -45,14 +48,15 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
block_out_channels=(32, 64), block_out_channels=(4, 8),
layers_per_block=2, layers_per_block=2,
sample_size=32, sample_size=8,
norm_num_groups=1,
in_channels=4, in_channels=4,
out_channels=4, out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32, cross_attention_dim=8,
) )
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_start=0.00085, beta_start=0.00085,
...@@ -63,7 +67,8 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes ...@@ -63,7 +67,8 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
) )
torch.manual_seed(0) torch.manual_seed(0)
vae = AutoencoderKL( vae = AutoencoderKL(
block_out_channels=[32, 64], block_out_channels=[4, 8],
norm_num_groups=1,
in_channels=3, in_channels=3,
out_channels=3, out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
...@@ -74,11 +79,11 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes ...@@ -74,11 +79,11 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
text_encoder_config = CLIPTextConfig( text_encoder_config = CLIPTextConfig(
bos_token_id=0, bos_token_id=0,
eos_token_id=2, eos_token_id=2,
hidden_size=32, hidden_size=8,
num_hidden_layers=2,
intermediate_size=37, intermediate_size=37,
layer_norm_eps=1e-05, layer_norm_eps=1e-05,
num_attention_heads=4, num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1, pad_token_id=1,
vocab_size=1000, vocab_size=1000,
) )
...@@ -108,13 +113,35 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes ...@@ -108,13 +113,35 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
"num_inference_steps": 2, "num_inference_steps": 2,
"guidance_scale": 1.0, "guidance_scale": 1.0,
"sag_scale": 1.0, "sag_scale": 1.0,
"output_type": "numpy", "output_type": "np",
} }
return inputs return inputs
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3) super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@unittest.skip("Not necessary to test here.")
def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_pipeline_different_schedulers(self):
pipeline = self.pipeline_class(**self.get_dummy_components())
inputs = self.get_dummy_inputs("cpu")
expected_image_size = (16, 16, 3)
for scheduler_cls in [DDIMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler]:
pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
image = pipeline(**inputs).images[0]
shape = image.shape
assert shape == expected_image_size
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
with self.assertRaises(ValueError):
# Karras schedulers are not supported
image = pipeline(**inputs).images[0]
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment