Unverified Commit 64cbd8e2 authored by Suraj Patil, committed by GitHub

Support LCM in ControlNet and Adapter pipelines. (#5822)

* support lcm

* fix tests

* fix tests
parent 038b42db
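
Editorial note (not part of the commit): the intended use of this change looks roughly like the sketch below. The checkpoint names and the control image are illustrative assumptions — any LCM-distilled Stable Diffusion checkpoint with a matching ControlNet should behave the same way.

```python
import torch
from PIL import Image
from diffusers import ControlNetModel, LCMScheduler, StableDiffusionControlNetPipeline

# Illustrative checkpoints (assumptions, not prescribed by this PR).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Swap in the LCM scheduler; an LCM-distilled UNet has time_cond_proj_dim set, so the
# pipeline now embeds the guidance scale instead of running classifier-free guidance.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

control_image = Image.new("RGB", (512, 512))  # placeholder; use a real Canny edge map in practice
image = pipe(
    "a futuristic city at dusk",
    image=control_image,
    num_inference_steps=4,   # LCM needs only a handful of steps
    guidance_scale=1.5,      # folded into timestep_cond by the new code path
).images[0]
```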
@@ -726,6 +726,46 @@ class StableDiffusionControlNetPipeline(
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype:
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -863,6 +903,8 @@ class StableDiffusionControlNetPipeline(
            control_guidance_end,
        )

+        self._guidance_scale = guidance_scale
+
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -872,10 +914,6 @@ class StableDiffusionControlNetPipeline(
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -895,7 +933,7 @@ class StableDiffusionControlNetPipeline(
            prompt,
            device,
            num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
@@ -905,7 +943,7 @@ class StableDiffusionControlNetPipeline(
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare image
@@ -918,7 +956,7 @@ class StableDiffusionControlNetPipeline(
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
@@ -934,7 +972,7 @@ class StableDiffusionControlNetPipeline(
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
@@ -962,6 +1000,14 @@ class StableDiffusionControlNetPipeline(
            latents,
        )

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -986,11 +1032,11 @@ class StableDiffusionControlNetPipeline(
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1017,7 +1063,7 @@ class StableDiffusionControlNetPipeline(
                    return_dict=False,
                )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
@@ -1029,6 +1075,7 @@ class StableDiffusionControlNetPipeline(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
@@ -1036,7 +1083,7 @@ class StableDiffusionControlNetPipeline(
                )[0]

                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
......
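
Editorial note on the new `timestep_cond` path (not part of the commit): `get_guidance_scale_embedding` turns the scalar `guidance_scale - 1` into a sinusoidal Fourier embedding, the same construction used for timestep embeddings, and a UNet configured with `time_cond_proj_dim` folds that vector into its timestep embedding. A minimal sketch of that consumption is shown below, using a simplified stand-in for diffusers' timestep-embedding module (assumed behaviour, not the actual class):

```python
from typing import Optional

import torch
import torch.nn as nn


class TimestepEmbeddingSketch(nn.Module):
    """Simplified stand-in showing how timestep_cond is merged into the timestep embedding."""

    def __init__(self, in_channels: int, time_embed_dim: int, cond_proj_dim: int):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
        # Only present when the UNet is built with time_cond_proj_dim (the LCM case).
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)

    def forward(self, t_emb: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        if timestep_cond is not None:
            # The guidance-scale embedding is added before the MLP, which is why the
            # pipelines skip classifier-free guidance when time_cond_proj_dim is set.
            t_emb = t_emb + self.cond_proj(timestep_cond)
        return self.linear_2(self.act(self.linear_1(t_emb)))


# Shape check mirroring the pipelines: cond_proj_dim equals unet.config.time_cond_proj_dim.
emb = TimestepEmbeddingSketch(in_channels=320, time_embed_dim=1280, cond_proj_dim=256)
timestep_cond = torch.randn(1, 256)  # stands in for get_guidance_scale_embedding(w, embedding_dim=256)
print(emb(torch.randn(1, 320), timestep_cond).shape)  # torch.Size([1, 1280])
```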
@@ -791,6 +791,46 @@ class StableDiffusionXLControlNetPipeline(
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype:
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -986,6 +1026,8 @@ class StableDiffusionXLControlNetPipeline(
            control_guidance_end,
        )

+        self._guidance_scale = guidance_scale
+
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -995,10 +1037,6 @@ class StableDiffusionXLControlNetPipeline(
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1024,7 +1062,7 @@ class StableDiffusionXLControlNetPipeline(
            prompt_2,
            device,
            num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds=prompt_embeds,
@@ -1045,7 +1083,7 @@ class StableDiffusionXLControlNetPipeline(
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
@@ -1061,7 +1099,7 @@ class StableDiffusionXLControlNetPipeline(
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
@@ -1089,6 +1127,14 @@ class StableDiffusionXLControlNetPipeline(
            latents,
        )

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1133,7 +1179,7 @@ class StableDiffusionXLControlNetPipeline(
        else:
            negative_add_time_ids = add_time_ids

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1154,13 +1200,13 @@ class StableDiffusionXLControlNetPipeline(
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1193,7 +1239,7 @@ class StableDiffusionXLControlNetPipeline(
                    return_dict=False,
                )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
@@ -1205,6 +1251,7 @@ class StableDiffusionXLControlNetPipeline(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
@@ -1213,7 +1260,7 @@ class StableDiffusionXLControlNetPipeline(
                )[0]

                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
......
@@ -610,6 +610,46 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype:
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -723,6 +763,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
            prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

+        self._guidance_scale = guidance_scale
+
        if isinstance(self.adapter, MultiAdapter):
            adapter_input = []
@@ -742,17 +784,12 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        else:
            batch_size = prompt_embeds.shape[0]

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
@@ -761,7 +798,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
@@ -784,6 +821,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
        # 7. Denoising loop
        if isinstance(self.adapter, MultiAdapter):
            adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -796,7 +841,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        if num_images_per_prompt > 1:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = torch.cat([v] * 2, dim=0)
@@ -804,7 +849,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
@@ -812,13 +857,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
                    return_dict=False,
                )[0]

                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
......
@@ -670,6 +670,46 @@ class StableDiffusionXLAdapterPipeline(
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype:
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -882,6 +922,8 @@ class StableDiffusionXLAdapterPipeline(
            negative_pooled_prompt_embeds,
        )

+        self._guidance_scale = guidance_scale
+
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -892,11 +934,6 @@ class StableDiffusionXLAdapterPipeline(

        device = self._execution_device

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
@@ -908,7 +945,7 @@ class StableDiffusionXLAdapterPipeline(
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
@@ -939,6 +976,14 @@ class StableDiffusionXLAdapterPipeline(
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
        # 7. Prepare added time ids & embeddings & adapter features
        if isinstance(self.adapter, MultiAdapter):
            adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -951,7 +996,7 @@ class StableDiffusionXLAdapterPipeline(
        if num_images_per_prompt > 1:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = torch.cat([v] * 2, dim=0)
@@ -979,7 +1024,7 @@ class StableDiffusionXLAdapterPipeline(
        else:
            negative_add_time_ids = add_time_ids

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1005,7 +1050,7 @@ class StableDiffusionXLAdapterPipeline(
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1021,6 +1066,7 @@ class StableDiffusionXLAdapterPipeline(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
@@ -1028,11 +1074,11 @@ class StableDiffusionXLAdapterPipeline(
                )[0]

                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
......
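
Editorial note (not part of the commit): the adapter pipelines get the same treatment, and usage mirrors the ControlNet sketch near the top of this page. A hedged sketch follows; the checkpoint names are illustrative assumptions.

```python
import torch
from PIL import Image
from diffusers import LCMScheduler, StableDiffusionAdapterPipeline, T2IAdapter

# Illustrative checkpoints (assumptions): an SD 1.5 T2I-Adapter plus an LCM-distilled SD 1.5 base.
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

control_image = Image.new("L", (512, 512))  # placeholder; a real Canny edge map is expected here
image = pipe(
    "a watercolor painting of a lighthouse",
    image=control_image,
    num_inference_steps=4,
    guidance_scale=1.5,
).images[0]
```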
@@ -27,6 +27,7 @@ from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    EulerDiscreteScheduler,
+    LCMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
)
@@ -116,7 +117,7 @@ class ControlNetPipelineFastTests(
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

-    def get_dummy_components(self):
+    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
@@ -128,6 +129,7 @@ class ControlNetPipelineFastTests(
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=1,
+            time_cond_proj_dim=time_cond_proj_dim,
        )
        torch.manual_seed(0)
        controlnet = ControlNetModel(
@@ -221,6 +223,28 @@ class ControlNetPipelineFastTests(
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_controlnet_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionControlNetPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array(
+            [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786]
+        )
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

class StableDiffusionMultiControlNetPipelineFastTests(
    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
......
@@ -24,6 +24,7 @@ from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
+    LCMScheduler,
    StableDiffusionXLControlNetPipeline,
    UNet2DConditionModel,
)
@@ -62,7 +63,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

-    def get_dummy_components(self):
+    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
@@ -80,6 +81,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
+            time_cond_proj_dim=time_cond_proj_dim,
        )
        torch.manual_seed(0)
        controlnet = ControlNetModel(
@@ -330,6 +332,26 @@ class StableDiffusionXLControlNetPipelineFastTests(
        # make sure that it's equal
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4

+    def test_controlnet_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLControlNetPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.7799, 0.614, 0.6162, 0.7082, 0.6662, 0.5833, 0.4148, 0.5182, 0.4866])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

class StableDiffusionXLMultiControlNetPipelineFastTests(
    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
......
@@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
    AutoencoderKL,
+    LCMScheduler,
    MultiAdapter,
    PNDMScheduler,
    StableDiffusionAdapterPipeline,
@@ -56,7 +57,7 @@ class AdapterTests:
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

-    def get_dummy_components(self, adapter_type):
+    def get_dummy_components(self, adapter_type, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
@@ -67,6 +68,7 @@ class AdapterTests:
            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
+            time_cond_proj_dim=time_cond_proj_dim,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
@@ -264,13 +266,13 @@ class AdapterTests:
    @parameterized.expand(
        [
            # (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
-            ((4 * 8 + 1) * 8),
+            (((4 * 8 + 1) * 8),),
            # (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
-            ((4 * 4 + 1) * 16),
+            (((4 * 4 + 1) * 16),),
            # (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
-            ((4 * 2 + 1) * 32),
+            (((4 * 2 + 1) * 32),),
            # (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
-            ((4 * 1 + 1) * 64),
+            (((4 * 1 + 1) * 64),),
        ]
    )
    def test_multiple_image_dimensions(self, dim):
@@ -292,10 +294,30 @@ class AdapterTests:

        assert image.shape == (1, dim, dim, 3)

+    def test_adapter_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.4535, 0.5493, 0.4359, 0.5452, 0.6086, 0.4441, 0.5544, 0.501, 0.4859])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("full_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("full_adapter", time_cond_proj_dim=time_cond_proj_dim)

    def get_dummy_components_with_full_downscaling(self):
        return super().get_dummy_components_with_full_downscaling("full_adapter")
@@ -317,8 +339,8 @@ class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):

class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("light_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("light_adapter", time_cond_proj_dim=time_cond_proj_dim)

    def get_dummy_components_with_full_downscaling(self):
        return super().get_dummy_components_with_full_downscaling("light_adapter")
@@ -340,8 +362,8 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):

class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("multi_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)

    def get_dummy_components_with_full_downscaling(self):
        return super().get_dummy_components_with_full_downscaling("multi_adapter")
......
@@ -26,6 +26,7 @@ import diffusers
from diffusers import (
    AutoencoderKL,
    EulerDiscreteScheduler,
+    LCMScheduler,
    MultiAdapter,
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
@@ -59,7 +60,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

-    def get_dummy_components(self, adapter_type="full_adapter_xl"):
+    def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
@@ -77,6 +78,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
+            time_cond_proj_dim=time_cond_proj_dim,
        )
        scheduler = EulerDiscreteScheduler(
            beta_start=0.00085,
@@ -309,9 +311,9 @@ class StableDiffusionXLAdapterPipelineFastTests(
    @parameterized.expand(
        [
            # (dim=144) The internal feature map will be 9x9 after initial pixel unshuffling (downscaled x16).
-            ((4 * 2 + 1) * 16),
+            (((4 * 2 + 1) * 16),),
            # (dim=160) The internal feature map will be 5x5 after the first T2I down block (downscaled x32).
-            ((4 * 1 + 1) * 32),
+            (((4 * 1 + 1) * 32),),
        ]
    )
    def test_multiple_image_dimensions(self, dim):
@@ -367,12 +369,32 @@ class StableDiffusionXLAdapterPipelineFastTests(
    def test_save_load_optional_components(self):
        return self._test_save_load_optional_components()

+    def test_adapter_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

class StableDiffusionXLMultiAdapterPipelineFastTests(
    StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
-    def get_dummy_components(self):
-        return super().get_dummy_components("multi_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)

    def get_dummy_components_with_full_downscaling(self):
        return super().get_dummy_components_with_full_downscaling("multi_adapter")
@@ -569,6 +591,29 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
        if test_mean_pixel_difference:
            assert_mean_pixel_difference(output_batch[0][0], output[0][0])
+    def test_adapter_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
@slow
@require_torch_gpu
......