Unverified Commit 90250d9e authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Cast height, width to int inside prepare latents (#7691)

update
parent e5674015
...@@ -652,7 +652,12 @@ class StableDiffusionPipeline( ...@@ -652,7 +652,12 @@ class StableDiffusionPipeline(
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -224,7 +224,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi ...@@ -224,7 +224,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -795,7 +795,12 @@ class StableDiffusionInpaintPipeline( ...@@ -795,7 +795,12 @@ class StableDiffusionInpaintPipeline(
return_noise=False, return_noise=False,
return_image_latents=False, return_image_latents=False,
): ):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -730,7 +730,12 @@ class StableDiffusionInstructPix2PixPipeline( ...@@ -730,7 +730,12 @@ class StableDiffusionInstructPix2PixPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -876,7 +876,12 @@ class StableUnCLIPPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInver ...@@ -876,7 +876,12 @@ class StableUnCLIPPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInver
# 11. Prepare latent variables # 11. Prepare latent variables
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
latents = self.prepare_latents( latents = self.prepare_latents(
shape=shape, shape=shape,
dtype=prompt_embeds.dtype, dtype=prompt_embeds.dtype,
......
...@@ -543,7 +543,12 @@ class StableUnCLIPImg2ImgPipeline( ...@@ -543,7 +543,12 @@ class StableUnCLIPImg2ImgPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -581,7 +581,12 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM ...@@ -581,7 +581,12 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -740,7 +740,12 @@ class StableDiffusionDiffEditPipeline( ...@@ -740,7 +740,12 @@ class StableDiffusionDiffEditPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -476,7 +476,12 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -476,7 +476,12 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -500,7 +500,12 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM ...@@ -500,7 +500,12 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -441,7 +441,12 @@ class StableDiffusionKDiffusionPipeline( ...@@ -441,7 +441,12 @@ class StableDiffusionKDiffusionPipeline(
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if latents is None: if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else: else:
......
...@@ -497,7 +497,12 @@ class StableDiffusionXLKDiffusionPipeline( ...@@ -497,7 +497,12 @@ class StableDiffusionXLKDiffusionPipeline(
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -627,7 +627,12 @@ class StableDiffusionLDM3DPipeline( ...@@ -627,7 +627,12 @@ class StableDiffusionLDM3DPipeline(
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -638,7 +638,12 @@ class StableDiffusionPanoramaPipeline( ...@@ -638,7 +638,12 @@ class StableDiffusionPanoramaPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -416,7 +416,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda ...@@ -416,7 +416,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -535,7 +535,12 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua ...@@ -535,7 +535,12 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -685,7 +685,12 @@ class StableDiffusionXLPipeline( ...@@ -685,7 +685,12 @@ class StableDiffusionXLPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -880,7 +880,12 @@ class StableDiffusionXLInpaintPipeline( ...@@ -880,7 +880,12 @@ class StableDiffusionXLInpaintPipeline(
return_noise=False, return_noise=False,
return_image_latents=False, return_image_latents=False,
): ):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -487,7 +487,12 @@ class StableDiffusionXLInstructPix2PixPipeline( ...@@ -487,7 +487,12 @@ class StableDiffusionXLInstructPix2PixPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
...@@ -569,7 +569,12 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -569,7 +569,12 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment