Unverified commit 47bf8e56, authored by Will Berman, committed by GitHub

can call encode_prompt without setting a text encoder instance variable (#4396)

* can call encode_prompt without setting a text encoder instance variable

* fix
parent 18fc40c1
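
Every pipeline below previously cast prompt embeddings with `self.text_encoder.dtype`, which raises an `AttributeError` when the pipeline is constructed without a text encoder (for example, when precomputed `prompt_embeds` are passed in). The fix resolves the target dtype through a fallback chain: the text encoder's dtype if one is loaded, else the UNet's dtype, else the dtype of the embeddings themselves. A minimal standalone sketch of the pattern, with illustrative names (`resolve_prompt_embeds_dtype` is not part of the diffusers API):

```python
import torch

# Sketch of the dtype fallback introduced by this commit; the function name
# and arguments are illustrative, not diffusers API.
def resolve_prompt_embeds_dtype(text_encoder, unet, prompt_embeds):
    if text_encoder is not None:
        return text_encoder.dtype
    elif unet is not None:
        return unet.dtype
    return prompt_embeds.dtype

# With neither component loaded, the embeddings keep their own dtype.
embeds = torch.randn(1, 77, 768, dtype=torch.float16)
assert resolve_prompt_embeds_dtype(None, None, embeds) == torch.float16
```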
@@ -334,7 +334,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -390,7 +397,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
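
For context, this is the kind of call the change enables; the model id, dtype, and embedding values below are illustrative assumptions, and in practice the embeddings would be precomputed elsewhere (e.g. by a separately loaded text encoder):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline without a text encoder; encode_prompt now falls back
# to the UNet's dtype instead of raising an AttributeError.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

# Embeddings computed ahead of time (random values here, only to show the
# shapes SD 1.5's CLIP text encoder would produce).
prompt_embeds = torch.randn(1, 77, 768, dtype=torch.float16, device="cuda")
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]
```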
@@ -335,7 +335,14 @@ class AltDiffusionImg2ImgPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -391,7 +398,7 @@ class AltDiffusionImg2ImgPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -326,7 +326,14 @@ class StableDiffusionControlNetPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -382,7 +389,7 @@ class StableDiffusionControlNetPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -352,7 +352,14 @@ class StableDiffusionControlNetImg2ImgPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -408,7 +415,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -469,7 +469,14 @@ class StableDiffusionControlNetInpaintPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -525,7 +532,7 @@ class StableDiffusionControlNetInpaintPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -346,7 +346,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -402,7 +409,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -336,7 +336,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -392,7 +399,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -335,7 +335,14 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -391,7 +398,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -220,7 +220,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -276,7 +283,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -521,7 +521,14 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -577,7 +584,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -340,7 +340,14 @@ class StableDiffusionImg2ImgPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -396,7 +403,7 @@ class StableDiffusionImg2ImgPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -398,7 +398,14 @@ class StableDiffusionInpaintPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -454,7 +461,7 @@ class StableDiffusionInpaintPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -336,7 +336,14 @@ class StableDiffusionInpaintPipelineLegacy(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -392,7 +399,7 @@ class StableDiffusionInpaintPipelineLegacy(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -243,7 +243,14 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -299,7 +306,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -305,7 +305,14 @@ class StableDiffusionLDM3DPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -361,7 +368,7 @@ class StableDiffusionLDM3DPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -247,7 +247,14 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -303,7 +310,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -224,7 +224,14 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -280,7 +287,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -289,7 +289,14 @@ class StableDiffusionParadigmsPipeline(
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -345,7 +352,7 @@ class StableDiffusionParadigmsPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -479,7 +479,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -535,7 +542,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -247,7 +247,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -303,7 +310,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)