"vscode:/vscode.git/clone" did not exist on "81e6345ddf893c594f6d76406a388fa012cb0a29"
Unverified commit c009c203 authored by Álvaro Somoza, committed by GitHub

[SDXL] Fix uncaught error with image to image (#8856)

* initial commit

* apply suggestion to sdxl pipelines

* apply fix to sd pipelines
Parent: 3f141176
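The same guard is added to the img2img `prepare_latents` implementation of each affected pipeline (diff hunks below). As a minimal standalone sketch of what the new branch does (the helper name and the toy tensor are illustrative, not part of the patch): when `generator` is a list and the preprocessed `image` batch is smaller than the effective batch size, the image is tiled along the batch dimension if the sizes divide evenly; otherwise a descriptive ValueError is raised up front rather than letting the mismatch surface later as the uncaught error the title refers to.

import torch

def broadcast_image_to_batch(image: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Illustrative helper mirroring the added branch; not part of diffusers itself.
    if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
        # One init image (or a smaller batch) is repeated to cover every generator.
        image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
    elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
        # Sizes that do not divide evenly cannot be broadcast unambiguously.
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
        )
    return image

# Toy check: one image and an effective batch size of 4 -> the image is repeated 4 times.
assert broadcast_image_to_batch(torch.randn(1, 3, 64, 64), 4).shape[0] == 4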
@@ -824,6 +824,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -930,6 +930,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -528,6 +528,13 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -520,6 +520,13 @@ class LatentConsistencyModelImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -719,6 +719,13 @@ class StableDiffusionXLPAGImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -494,6 +494,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -740,6 +740,13 @@ class StableDiffusionImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)
@@ -710,6 +710,13 @@ class StableDiffusionXLImg2ImgPipeline(
                 )
             elif isinstance(generator, list):
+                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                    )
                 init_latents = [
                     retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                     for i in range(batch_size)