"vscode:/vscode.git/clone" did not exist on "deed945625418b1f2625048e22350e528f796cbc"
Unverified commit 3d7eaf83, authored by Patrick von Platen, committed by GitHub

LCM Add Tests (#5707)

* lcm add tests

* uP

* Fix all

* uP

* Add

* all

* uP

* uP

* uP

* uP

* uP

* uP

* uP
parent bf406ea8
@@ -1411,6 +1411,11 @@ class LoraLoaderMixin:
             filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
         )
 
+        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
         if len(targeted_files) > 1:
             raise ValueError(
                 f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
...
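For context on the LoraLoaderMixin hunk above: when a repository ships LoRA weights in both serialization formats, the loader now narrows the candidate list to one format before the ambiguity check, instead of raising immediately. A minimal sketch of the selection rule, assuming the usual values of the two constants (the file list is invented for illustration):

# Assumed constant values; the real ones are defined inside diffusers.
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

targeted_files = ["pytorch_lora_weights.bin", "pytorch_lora_weights.safetensors"]

# Prefer the torch-format name; otherwise fall back to the safetensors name.
# Only if several candidates survive does the ValueError below still fire.
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
    targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME)]
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
    targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME_SAFE)]

assert targeted_files == ["pytorch_lora_weights.bin"]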
@@ -588,6 +588,34 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -605,7 +633,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -804,6 +832,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -818,6 +854,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
...
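For orientation, `get_guidance_scale_embedding` implements the w-conditioning used by latent consistency models: the guidance scale is mapped to a sinusoidal Fourier embedding, exactly like a timestep embedding, and projected into the UNet through `time_cond_proj_dim`. A standalone sketch of the same math, written as a module-level function rather than the pipeline method:

import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, dtype=torch.float32) -> torch.Tensor:
    # w: 1-D tensor with one guidance scale value per sample in the batch
    assert w.ndim == 1
    w = w * 1000.0  # same scaling applied to timesteps before embedding
    half_dim = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype) * -log_step)
    args = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd embedding sizes
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# The pipelines embed guidance_scale - 1, repeated once per generated image:
cond = guidance_scale_embedding(torch.tensor([7.5 - 1.0]).repeat(2), embedding_dim=256)
print(cond.shape)  # torch.Size([2, 256])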
@@ -646,6 +646,34 @@ class AltDiffusionImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -659,7 +687,7 @@ class AltDiffusionImg2ImgPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -849,6 +877,14 @@ class AltDiffusionImg2ImgPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 7.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -863,6 +899,7 @@ class AltDiffusionImg2ImgPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
...
@@ -576,6 +576,35 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -593,7 +622,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -790,6 +819,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -804,6 +841,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
...
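The amended `do_classifier_free_guidance` predicate is the behavioural switch in this commit: an LCM-distilled UNet carries a `time_cond_proj_dim` entry in its config, and for such a model the pipeline must not double the batch for classifier-free guidance, since the guidance strength is already injected through the w-embedding above. A toy illustration with stand-in config objects (`time_cond_proj_dim=256` is an assumed value, not taken from this diff):

from types import SimpleNamespace

def do_classifier_free_guidance(guidance_scale, unet_config):
    return guidance_scale > 1 and unet_config.time_cond_proj_dim is None

regular_unet_config = SimpleNamespace(time_cond_proj_dim=None)
lcm_unet_config = SimpleNamespace(time_cond_proj_dim=256)

print(do_classifier_free_guidance(7.5, regular_unet_config))  # True: run cond + uncond passes
print(do_classifier_free_guidance(7.5, lcm_unet_config))      # False: guidance arrives via timestep_cond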
@@ -640,6 +640,35 @@ class StableDiffusionImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -653,7 +682,7 @@ class StableDiffusionImg2ImgPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -841,6 +870,14 @@ class StableDiffusionImg2ImgPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 7.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -855,6 +892,7 @@ class StableDiffusionImg2ImgPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
...
@@ -765,6 +765,35 @@ class StableDiffusionInpaintPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -778,7 +807,7 @@ class StableDiffusionInpaintPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -1087,6 +1116,14 @@ class StableDiffusionInpaintPipeline(
         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 9.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 10. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -1106,6 +1143,7 @@ class StableDiffusionInpaintPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
...
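Taken together, the changes to the three Stable Diffusion pipelines above (text-to-image, img2img, inpaint) let an LCM-distilled checkpoint run through the ordinary pipeline classes: CFG is skipped and `guidance_scale` reaches the UNet as `timestep_cond`. A hedged usage sketch; the checkpoint id is a placeholder, not something this commit references:

import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Hypothetical LCM-distilled checkpoint, named here for illustration only.
pipe = DiffusionPipeline.from_pretrained("some-org/lcm-distilled-sd", torch_dtype=torch.float16)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Few-step sampling: guidance_scale is consumed through the w-embedding, not CFG.
image = pipe("a photo of an astronaut", num_inference_steps=4, guidance_scale=7.5).images[0]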
@@ -636,6 +636,35 @@ class StableDiffusionXLPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -653,7 +682,7 @@ class StableDiffusionXLPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -989,6 +1018,14 @@ class StableDiffusionXLPipeline(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
 
+        # 9. Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1003,6 +1040,7 @@ class StableDiffusionXLPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
...
@@ -763,6 +763,35 @@ class StableDiffusionXLImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -780,7 +809,7 @@ class StableDiffusionXLImg2ImgPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -1156,6 +1185,15 @@ class StableDiffusionXLImg2ImgPipeline(
             )
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
+
+        # 9.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1170,6 +1208,7 @@ class StableDiffusionXLImg2ImgPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
...
@@ -982,6 +982,35 @@ class StableDiffusionXLInpaintPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                generate embedding vectors at these guidance scale values
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -999,7 +1028,7 @@ class StableDiffusionXLInpaintPipeline(
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
     @property
     def cross_attention_kwargs(self):
@@ -1464,6 +1493,14 @@ class StableDiffusionXLInpaintPipeline(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
 
+        # 11.1 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1482,6 +1519,7 @@ class StableDiffusionXLInpaintPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
...
@@ -28,10 +28,12 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
 from diffusers import (
     AutoencoderKL,
+    AutoPipelineForImage2Image,
     ControlNetModel,
     DDIMScheduler,
     DiffusionPipeline,
     EulerDiscreteScheduler,
+    LCMScheduler,
     StableDiffusionPipeline,
     StableDiffusionXLControlNetPipeline,
     StableDiffusionXLPipeline,
@@ -107,10 +109,12 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     vae_kwargs = None
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, scheduler_cls=None):
+        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)
-        scheduler = self.scheduler_cls(**self.scheduler_kwargs)
+        scheduler = scheduler_cls(**self.scheduler_kwargs)
         torch.manual_seed(0)
         vae = AutoencoderKL(**self.vae_kwargs)
         text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
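Every test below now loops over both scheduler classes, and relies on the small signature change above: an explicit `scheduler_cls` wins, while `None` falls back to the mixin's class attribute. A self-contained sketch of that default-argument pattern, with stand-in string values instead of real scheduler classes:

class DummyMixin:
    scheduler_cls = "DDIMScheduler"  # stand-in for the real class attribute

    def get_dummy_components(self, scheduler_cls=None):
        # The explicit argument wins; None means "use the class default".
        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        return scheduler_cls

m = DummyMixin()
print(m.get_dummy_components())                # DDIMScheduler (fallback)
print(m.get_dummy_components("LCMScheduler"))  # LCMScheduler (override)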
...@@ -200,746 +204,806 @@ class PeftLoraLoaderMixinTests: ...@@ -200,746 +204,806 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
components, _, _, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs() _, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs).images output_no_lora = pipe(**inputs).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
""" """
Tests a simple inference with lora attached on the text encoder Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected and makes sure it works as expected
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
if self.has_two_text_encoders: pipe.text_encoder.add_adapter(text_lora_config)
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
self.assertTrue( pipe.text_encoder_2.add_adapter(text_lora_config)
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" self.assertTrue(
) self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
""" """
Tests a simple inference with lora attached on the text encoder + scale argument Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected and makes sure it works as expected
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
if self.has_two_text_encoders: pipe.text_encoder.add_adapter(text_lora_config)
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
self.assertTrue( pipe.text_encoder_2.add_adapter(text_lora_config)
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" self.assertTrue(
) self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora_scale = pipe( output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} self.assertTrue(
).images not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
self.assertTrue( )
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
output_lora_0_scale = pipe( output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images ).images
self.assertTrue( self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA", "Lora + scale should change the output",
) )
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
""" """
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected and makes sure it works as expected
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
pipe.fuse_lora() if self.has_two_text_encoders:
# Fusing should still keep the LoRA layers pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders: pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
self.assertFalse( self.assertTrue(
np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_text_lora_unloaded(self): def test_simple_inference_with_text_lora_unloaded(self):
""" """
Tests a simple inference with lora attached to text encoder, then unloads the lora weights Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
if self.has_two_text_encoders: pipe.text_encoder.add_adapter(text_lora_config)
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
pipe.unload_lora_weights() if self.has_two_text_encoders:
# unloading should remove the LoRA layers pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertFalse( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
if self.has_two_text_encoders: pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse( self.assertFalse(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
) )
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
self.assertTrue( self.assertFalse(
np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" self.check_if_lora_correctly_set(pipe.text_encoder_2),
) "Lora not correctly unloaded in text encoder 2",
)
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
""" """
Tests a simple usecase where users could use saving utilities for LoRA. Tests a simple usecase where users could use saving utilities for LoRA.
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders: if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.pipeline_class.save_lora_weights( self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images self.pipeline_class.save_lora_weights(
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
)
if self.has_two_text_encoders: self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
self.assertTrue( if self.has_two_text_encoders:
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), self.assertTrue(
"Loading from saved checkpoints should give same results.", self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_save_pretrained(self): def test_simple_inference_save_pretrained(self):
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
components, _, text_lora_config, _ = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
if self.has_two_text_encoders: pipe.text_encoder.add_adapter(text_lora_config)
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
) )
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
with tempfile.TemporaryDirectory() as tmpdirname: images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) with tempfile.TemporaryDirectory() as tmpdirname:
pipe_from_pretrained.to(self.torch_device) pipe.save_pretrained(tmpdirname)
self.assertTrue( pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), pipe_from_pretrained.to(self.torch_device)
"Lora not correctly set in text encoder",
)
if self.has_two_text_encoders:
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
"Lora not correctly set in text encoder 2", "Lora not correctly set in text encoder",
) )
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images if self.has_two_text_encoders:
self.assertTrue(
self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
"Lora not correctly set in text encoder 2",
)
self.assertTrue( images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.", self.assertTrue(
) np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_text_unet_lora_save_load(self): def test_simple_inference_with_text_unet_lora_save_load(self):
""" """
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
components, _, text_lora_config, unet_lora_config = self.get_dummy_components() for scheduler_cls in [DDIMScheduler, LCMScheduler]:
pipe = self.pipeline_class(**components) components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = pipe.to(self.torch_device) pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None) pipe = pipe.to(self.torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)
            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            with tempfile.TemporaryDirectory() as tmpdirname:
                text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
                unet_state_dict = get_peft_model_state_dict(pipe.unet)

                if self.has_two_text_encoders:
                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

                    self.pipeline_class.save_lora_weights(
                        save_directory=tmpdirname,
                        text_encoder_lora_layers=text_encoder_state_dict,
                        text_encoder_2_lora_layers=text_encoder_2_state_dict,
                        unet_lora_layers=unet_state_dict,
                        safe_serialization=False,
                    )
                else:
                    self.pipeline_class.save_lora_weights(
                        save_directory=tmpdirname,
                        text_encoder_lora_layers=text_encoder_state_dict,
                        unet_lora_layers=unet_state_dict,
                        safe_serialization=False,
                    )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()

                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )
    def test_simple_inference_with_text_unet_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)
            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            output_lora_scale = pipe(
                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
            ).images
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            output_lora_0_scale = pipe(
                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
            ).images
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            self.assertTrue(
                pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                "The scaling parameter has not been correctly restored!",
            )
    def test_simple_inference_with_text_lora_unet_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")

            if self.has_two_text_encoders:
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )
    def test_simple_inference_with_text_unet_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")

            if self.has_two_text_encoders:
                self.assertFalse(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly unloaded in text encoder 2",
                )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should give the same output as no LoRA",
            )
    def test_simple_inference_with_text_unet_lora_unfused(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then fuses and unfuses
        the lora weights and makes sure it works as expected
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.fuse_lora()
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
            pipe.unfuse_lora()
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            # unfusing should keep the LoRA layers in place
            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders:
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
                "Fusing and unfusing should lead to the same results",
            )
    def test_simple_inference_with_text_unet_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images

            # Each adapter (and their combination) should give different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them with different weights
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images

            # Each adapter (and their combination) should give different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertFalse(
                np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_lora_fuse_nan(self):
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
                    "inf"
                )

            # with `safe_fusing=True` we should see an error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(safe_fusing=True)

            # without it we should not see an error, but every image will be black
            pipe.fuse_lora(safe_fusing=False)

            out = pipe("test", num_inference_steps=2, output_type="np").images

            self.assertTrue(np.isnan(out).all())
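Outside the test suite, the `safe_fusing` flag exercised above is the guard users would reach for before committing to fused weights. A minimal sketch, assuming a standard base checkpoint and a hypothetical LoRA repo id:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repo id
try:
    # validates the fused weights and raises instead of silently producing NaN images
    pipe.fuse_lora(safe_fusing=True)
except ValueError:
    # fall back to the un-fused base weights
    pipe.unload_lora_weights()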
    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-1"])

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-2"])

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

            adapter_names = pipe.get_list_adapters()
            self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

            adapter_names = pipe.get_list_adapters()
            self.assertDictEqual(
                adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertDictEqual(
                pipe.get_list_adapters(),
                {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
            )

            pipe.unet.add_adapter(unet_lora_config, "adapter-3")
            self.assertDictEqual(
                pipe.get_list_adapters(),
                {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
            )
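For reference, the multi-adapter API covered by the two tests above looks roughly like this in user code; a sketch, assuming hypothetical repo ids and adapter names:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")      # hypothetical repo id
pipe.load_lora_weights("some-user/subject-lora", adapter_name="subject")  # hypothetical repo id

pipe.set_adapters(["style", "subject"], [0.5, 0.6])  # weighted combination, as in the tests
print(pipe.get_active_adapters())  # ["style", "subject"]
print(pipe.get_list_adapters())    # {"unet": ["style", "subject"], "text_encoder": ["style", "subject"]}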
...@@ -947,32 +1011,35 @@ class PeftLoraLoaderMixinTests:
    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, compiles both with
        torch.compile and makes sure inference still works as expected
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(self.torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            pipe.unet.add_adapter(unet_lora_config)

            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )
            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

            if self.has_two_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure it works.
            _ = pipe(**inputs, generator=torch.manual_seed(0)).images
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
...@@ -1574,6 +1641,97 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
        self.assertTrue(np.allclose(images, expected, atol=1e-4))

        release_memory(pipe)
    def test_sdxl_lcm_lora(self):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        )
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()

        generator = torch.Generator().manual_seed(0)

        lora_model_id = "latent-consistency/lcm-lora-sdxl"
        pipe.load_lora_weights(lora_model_id)

        image = pipe(
            "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
        ).images[0]

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdxl_lcm_lora.png"
        )

        image_np = pipe.image_processor.pil_to_numpy(image)
        expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)

        self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))

        pipe.unload_lora_weights()

        release_memory(pipe)

    def test_sdv1_5_lcm_lora(self):
        pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        pipe.to("cuda")
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

        generator = torch.Generator().manual_seed(0)

        lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
        pipe.load_lora_weights(lora_model_id)

        image = pipe(
            "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
        ).images[0]

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora.png"
        )

        image_np = pipe.image_processor.pil_to_numpy(image)
        expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)

        self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))

        pipe.unload_lora_weights()

        release_memory(pipe)

    def test_sdv1_5_lcm_lora_img2img(self):
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
        pipe.to("cuda")
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
        )

        generator = torch.Generator().manual_seed(0)

        lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
        pipe.load_lora_weights(lora_model_id)

        image = pipe(
            "snowy mountain",
            generator=generator,
            image=init_image,
            strength=0.5,
            num_inference_steps=4,
            guidance_scale=0.5,
        ).images[0]

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora_img2img.png"
        )

        image_np = pipe.image_processor.pil_to_numpy(image)
        expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)

        self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))

        pipe.unload_lora_weights()

        release_memory(pipe)
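The three integration tests above share one recipe: swap in `LCMScheduler`, load an LCM-LoRA, and sample with very few steps at low guidance. Condensed into a stand-alone sketch (model ids as in the tests, prompt illustrative):

import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)  # LCM requires its own scheduler
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

# LCM distillation makes 2-8 step sampling viable, hence the low step count and guidance scale
image = pipe("masterpiece, best quality, mountain", num_inference_steps=4, guidance_scale=0.5).images[0]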
    def test_sdxl_1_0_lora_fusion(self):
        generator = torch.Generator().manual_seed(0)
...
...@@ -31,6 +31,7 @@ from diffusers import (
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
...@@ -41,6 +42,7 @@ from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
    CaptureLogger,
    enable_full_determinism,
    load_image,
    load_numpy,
    nightly,
    numpy_cosine_similarity_distance,
...@@ -107,12 +109,13 @@ class StableDiffusionPipelineFastTests(
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            time_cond_proj_dim=time_cond_proj_dim,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
...@@ -196,6 +199,26 @@ class StableDiffusionPipelineFastTests(
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs)
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
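A note on `time_cond_proj_dim=256` in this test and the LCM tests below: LCM-distilled UNets condition on an embedding of the guidance scale through this extra projection, which is what the dummy UNet emulates here, so classifier-free-guidance prompt batching is skipped at inference. A quick way to inspect the flag on a real LCM checkpoint (the printed value of 256 is an assumption mirroring the dummy test value, not verified here):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
print(unet.config.time_cond_proj_dim)  # expected: 256 (assumed; check the hosted config)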
    def test_stable_diffusion_prompt_embeds(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**components)
...@@ -1066,6 +1089,29 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
        inputs["seed"] = seed
        run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
    def test_stable_diffusion_lcm(self):
        unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
        sd_pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", unet=unet).to(torch_device)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 6
        inputs["output_type"] = "pil"
        image = sd_pipe(**inputs).images[0]

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_lcm.png"
        )

        image = sd_pipe.image_processor.pil_to_numpy(image)
        expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)

        max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())

        assert max_diff < 1e-2
@slow
@require_torch_gpu
...
...@@ -36,7 +36,6 @@ from diffusers.utils.testing_utils import (
    load_numpy,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
...@@ -202,7 +201,6 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8449, 0.9079, 0.7571, 0.7873, 0.8348, 0.7010, 0.6694, 0.6873, 0.6138])

        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 1e-4
...
...@@ -28,6 +28,7 @@ from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    HeunDiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
...@@ -103,11 +104,12 @@ class StableDiffusionImg2ImgPipelineFastTests(
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            time_cond_proj_dim=time_cond_proj_dim,
            sample_size=32,
            in_channels=4,
            out_channels=4,
...@@ -187,6 +189,23 @@ class StableDiffusionImg2ImgPipelineFastTests(
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
    def test_stable_diffusion_img2img_default_case_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.5709, 0.4614, 0.4587, 0.5978, 0.5298, 0.6910, 0.6240, 0.5212, 0.5454])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
    def test_stable_diffusion_img2img_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
...
...@@ -29,6 +29,7 @@ from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
...@@ -106,10 +107,11 @@ class StableDiffusionInpaintPipelineFastTests(
    image_latents_params = frozenset([])
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            time_cond_proj_dim=time_cond_proj_dim,
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
...@@ -206,6 +208,23 @@ class StableDiffusionInpaintPipelineFastTests(
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_inpaint_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_inpaint_image_tensor(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
...@@ -288,11 +307,12 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
    image_params = frozenset([])
    # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            time_cond_proj_dim=time_cond_proj_dim,
            sample_size=32,
            in_channels=4,
            out_channels=4,
...@@ -381,6 +401,23 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_inpaint_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_inpaint_2_images(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
...
...@@ -27,12 +27,20 @@ from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LCMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    load_image,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import (
    TEXT_TO_IMAGE_BATCH_PARAMS,
...@@ -56,11 +64,12 @@ class StableDiffusionXLPipelineFastTests(
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(2, 4),
            layers_per_block=2,
            time_cond_proj_dim=time_cond_proj_dim,
            sample_size=32,
            in_channels=4,
            out_channels=4,
...@@ -155,6 +164,23 @@ class StableDiffusionXLPipelineFastTests(
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_xl_euler_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionXLPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_xl_prompt_embeds(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**components)
...@@ -890,3 +916,32 @@ class StableDiffusionXLPipelineFastTests(
        image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
    def test_stable_diffusion_lcm(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
        )
        sd_pipe = StableDiffusionXLPipeline.from_pretrained(
            "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
        ).to(torch_device)

        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "a red car standing on the side of the street"

        image = sd_pipe(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_ssd_1b_lcm.png"
        )

        image = sd_pipe.image_processor.pil_to_numpy(image)
        expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)

        max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())

        assert max_diff < 1e-2
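Unlike the LoRA-based tests earlier in this PR, this integration test swaps in a fully distilled LCM UNet rather than layering an LCM-LoRA over the base weights; the scheduler swap is the only step shared between the two approaches. The full-model path, condensed into a sketch (ids as in the test, device assumed to be CUDA):

import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# note: full-model LCM checkpoints still accept a guidance scale (embedded, not CFG batching)
image = pipe("a red car standing on the side of the street", num_inference_steps=4, guidance_scale=8.0).images[0]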
...@@ -24,6 +24,7 @@ from diffusers import (
    AutoencoderKL,
    AutoencoderTiny,
    EulerDiscreteScheduler,
    LCMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    UNet2DConditionModel,
)
...@@ -57,7 +58,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
        {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
    )

    def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
...@@ -65,6 +66,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
            sample_size=32,
            in_channels=4,
            out_channels=4,
            time_cond_proj_dim=time_cond_proj_dim,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
...@@ -172,6 +174,24 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_img2img_euler_lcm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
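Note that `LCMScheduler.from_config` keeps whatever hyperparameters the two schedulers share and fills the rest with LCM defaults, which is why swapping schedulers mid-pipeline is cheap. A small sketch of the same swap outside a pipeline (assuming default scheduler configs):

from diffusers import EulerDiscreteScheduler, LCMScheduler

euler = EulerDiscreteScheduler()
lcm = LCMScheduler.from_config(euler.config)

# Shared keys such as num_train_timesteps carry over from the Euler config.
assert lcm.config.num_train_timesteps == euler.config.num_train_timesteps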
    def test_attention_slicing_forward_pass(self):
        super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
...
@@ -28,6 +28,7 @@ from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LCMScheduler,
    StableDiffusionXLInpaintPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
@@ -61,7 +62,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
        }
    )

    def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
@@ -69,6 +70,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
            sample_size=32,
            in_channels=4,
            out_channels=4,
            time_cond_proj_dim=time_cond_proj_dim,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
@@ -209,6 +211,24 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    def test_stable_diffusion_xl_inpaint_euler_lcm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components(time_cond_proj_dim=256)
        sd_pipe = StableDiffusionXLInpaintPipeline(**components)
        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
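When writing fast tests like the two above, the nine reference values in `expected_slice` are typically copied from a trusted local run; a hypothetical helper for printing them in the format these tests use (the name and formatting are illustrative, not part of this diff):

import numpy as np

def print_expected_slice(image: np.ndarray) -> None:
    # The tests check the bottom-right 3x3 patch of the last channel.
    patch = image[0, -3:, -3:, -1].flatten()
    print("np.array([" + ", ".join(f"{v:.4f}" for v in patch) + "])")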
    def test_attention_slicing_forward_pass(self):
        super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
...