Unverified Commit c009c203 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[SDXL] Fix uncaught error with image to image (#8856)

* initial commit

* apply suggestion to sdxl pipelines

* apply fix to sd pipelines
parent 3f141176
......@@ -824,6 +824,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -930,6 +930,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -528,6 +528,13 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -520,6 +520,13 @@ class LatentConsistencyModelImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -719,6 +719,13 @@ class StableDiffusionXLPAGImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -494,6 +494,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -740,6 +740,13 @@ class StableDiffusionImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
......@@ -710,6 +710,13 @@ class StableDiffusionXLImg2ImgPipeline(
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment