"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "256e0106749363fce06c28000698edeaf56a874d"
Unverified Commit 17528afc authored by Patrick von Platen, committed by GitHub

Fix styling issues (#5699)



* up

* up

* up

* Empty-Commit

* fix keyword argument call.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 78be4007
@@ -206,17 +206,15 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
 
             prompt_embeds = prior_text_encoder_output.text_embeds
-            prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+            text_enc_hid_states = prior_text_encoder_output.last_hidden_state
         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask
 
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
-            num_images_per_prompt, dim=0
-        )
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
 
         if do_classifier_free_guidance:
@@ -235,9 +233,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )
 
             negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
-            uncond_prior_text_encoder_hidden_states = (
-                negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
-            )
+            uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -245,11 +241,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
 
-            seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
-                1, num_images_per_prompt, 1
-            )
-            uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )
             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            prior_text_encoder_hidden_states = torch.cat(
-                [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
-            )
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
 
             text_mask = torch.cat([uncond_text_mask, text_mask])
 
-        return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
...
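For context on the duplication logic above (unchanged by the rename): `repeat_interleave` along `dim=0` copies each prompt's embeddings `num_images_per_prompt` times, keeping the copies of a given prompt adjacent. A minimal sketch, with shapes invented for illustration:

```python
import torch

prompt_embeds = torch.randn(2, 768)  # 2 prompts, hypothetical embedding width
num_images_per_prompt = 3

# Rows come out as [p0, p0, p0, p1, p1, p1]: copies of each prompt stay adjacent.
dup = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
assert dup.shape == (2 * num_images_per_prompt, 768)
```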
@@ -156,15 +156,15 @@ class UnCLIPPipeline(DiffusionPipeline):
             text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
             prompt_embeds = text_encoder_output.text_embeds
-            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_enc_hid_states = text_encoder_output.last_hidden_state
         else:
             batch_size = text_model_output[0].shape[0]
-            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
             text_mask = text_attention_mask
 
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
 
         if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
 
             negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ class UnCLIPPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
 
-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+            seq_len = uncond_text_enc_hid_states.shape[1]
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                 batch_size * num_images_per_prompt, seq_len, -1
             )
             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ class UnCLIPPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
 
             text_mask = torch.cat([uncond_text_mask, text_mask])
 
-        return prompt_embeds, text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_enc_hid_states, text_mask
 
     @torch.no_grad()
     def __call__(
@@ -293,7 +293,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
-        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+        prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
@@ -321,7 +321,7 @@ class UnCLIPPipeline(DiffusionPipeline):
                 latent_model_input,
                 timestep=t,
                 proj_embedding=prompt_embeds,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 attention_mask=text_mask,
             ).predicted_image_embedding
@@ -352,10 +352,10 @@ class UnCLIPPipeline(DiffusionPipeline):
         # decoder
 
-        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+        text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
             image_embeddings=image_embeddings,
             prompt_embeds=prompt_embeds,
-            text_encoder_hidden_states=text_encoder_hidden_states,
+            text_encoder_hidden_states=text_enc_hid_states,
             do_classifier_free_guidance=do_classifier_free_guidance,
         )
@@ -377,7 +377,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
-            text_encoder_hidden_states.dtype,
+            text_enc_hid_states.dtype,
             device,
             generator,
             decoder_latents,
@@ -391,7 +391,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             noise_pred = self.decoder(
                 sample=latent_model_input,
                 timestep=t,
-                encoder_hidden_states=text_encoder_hidden_states,
+                encoder_hidden_states=text_enc_hid_states,
                 class_labels=additive_clip_time_embeddings,
                 attention_mask=decoder_text_mask,
             ).sample
...
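The "mps friendly method" referenced in the comments is the `repeat` + `view` pair above, which produces the same adjacent-copies layout as `repeat_interleave` while avoiding ops the comment suggests were problematic on the mps backend. A small self-contained check, with made-up shapes:

```python
import torch

batch_size, seq_len, dim = 2, 77, 768
num_images_per_prompt = 3

hidden = torch.randn(batch_size, seq_len, dim)

# Tile along the sequence axis, then fold the copies back into the batch axis.
dup = hidden.repeat(1, num_images_per_prompt, 1).view(
    batch_size * num_images_per_prompt, seq_len, -1
)

# Equivalent to the straightforward repeat_interleave version.
assert torch.equal(dup, hidden.repeat_interleave(num_images_per_prompt, dim=0))
```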
@@ -1494,7 +1494,6 @@ class ResnetBlockFlat(nn.Module):
         return output_tensor
 
 
-# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class DownBlockFlat(nn.Module):
     def __init__(
         self,
@@ -1583,7 +1582,6 @@ class DownBlockFlat(nn.Module):
         return hidden_states, output_states
 
 
-# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
 class CrossAttnDownBlockFlat(nn.Module):
     def __init__(
         self,
...
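For reference, the "concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes" comments describe standard classifier-free guidance batching: both halves go through the model at once, then the two predictions are split apart and blended. A minimal sketch with a stand-in model and invented shapes, not the pipeline's actual API:

```python
import torch

def cfg_step(model, latents, t, uncond_embeds, cond_embeds, guidance_scale):
    # One batch holding [unconditional, conditional] halves -> single forward pass.
    embeds = torch.cat([uncond_embeds, cond_embeds])
    latent_model_input = torch.cat([latents] * 2)

    noise_pred = model(latent_model_input, t, encoder_hidden_states=embeds)

    # Split the halves back out and push the prediction away from the
    # unconditional direction by guidance_scale.
    noise_uncond, noise_cond = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

# Stand-in model so the sketch runs end to end; a real pipeline would call its UNet here.
dummy_model = lambda x, t, encoder_hidden_states: x * 0.9
out = cfg_step(dummy_model, torch.randn(2, 4, 8, 8), 10,
               torch.randn(2, 77, 768), torch.randn(2, 77, 768), 4.0)
assert out.shape == (2, 4, 8, 8)
```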