Unverified commit 3383f774, authored by Suraj Patil, committed by GitHub

update the clip guided PR according to the new API (#751)

parent df9c0701
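
This commit threads a new `num_images_per_prompt` argument through the CLIP-guided pipeline: the prompt, CLIP, and unconditional embeddings are each duplicated with `repeat_interleave`, the latent batch is widened to `batch_size * num_images_per_prompt`, and the LMS-specific scaling branches are replaced by the scheduler-agnostic `init_noise_sigma` / `scale_model_input` API. A minimal sketch of the duplication pattern used below (tensor sizes here are illustrative only, not taken from the pipeline):

```python
import torch

# two prompts, embedding dim 4 (illustrative sizes only)
text_embeddings = torch.randn(2, 4)

# num_images_per_prompt = 3: repeat_interleave keeps the copies of each
# prompt adjacent in the batch, i.e. rows [p0, p0, p0, p1, p1, p1]
duplicated = text_embeddings.repeat_interleave(3, dim=0)
print(duplicated.shape)  # torch.Size([6, 4])

# a tile-style repeat would instead interleave prompts: [p0, p1, p0, p1, p0, p1]
tiled = text_embeddings.repeat(3, 1)
```

Keeping each prompt's copies adjacent means the unconditional embeddings, duplicated the same way, line up row for row with the text embeddings when the two are concatenated for classifier-free guidance.
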
@@ -175,6 +175,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
@@ -203,6 +204,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

         if clip_guidance_scale > 0:
             if clip_prompt is not None:
@@ -217,6 +220,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             clip_text_input = text_input.input_ids.to(self.device)
             text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
             text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+            # duplicate text embeddings clip for each generation per prompt
+            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -225,10 +230,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -240,14 +245,16 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -261,17 +268,17 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)

+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -299,9 +306,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             )

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

         # scale and decode the image latents with vae
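
For context, a hedged sketch of how the updated pipeline might be invoked after this change. The loading path (`custom_pipeline="clip_guided_stable_diffusion"`) and the checkpoint names are assumptions for illustration, not part of this diff:

```python
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

# assumed model ids; any compatible CLIP + Stable Diffusion checkpoint pair should work
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion",  # assumed community-pipeline id
    clip_model=clip_model,
    feature_extractor=feature_extractor,
).to("cuda")

generator = torch.Generator("cuda").manual_seed(0)
out = pipe(
    prompt="a fantasy landscape, trending on artstation",
    num_inference_steps=50,
    guidance_scale=7.5,
    num_images_per_prompt=2,  # new argument introduced by this commit
    clip_guidance_scale=100,
    generator=generator,
)
images = out.images  # two images for the single prompt
```

Note that the latent batch is now `batch_size * num_images_per_prompt`, so an explicitly supplied `latents` tensor must match that widened shape, as the `raise ValueError` branch in the diff enforces.
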