Unverified commit 58933059, authored by Lukas Struppek and committed by GitHub
Browse files

VersatileDiffusion: fix input processing (#1568)



* fix versatile diffusion input

* merge main

* `make fix-copies`
Co-authored-by: anton- <anton@huggingface.co>
parent 31444f57
...@@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline): ...@@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
and not isinstance(image, list) and not isinstance(image, list)
): ):
raise ValueError( raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
) )
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
......
...@@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
and not isinstance(image, list) and not isinstance(image, list)
): ):
raise ValueError( raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
) )
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
......
...@@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds return embeds
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
prompt = [p for p in prompt]
batch_size = len(prompt) if isinstance(prompt, list) else 1 batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings # get prompt text embeddings
...@@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
def check_inputs(self, image, height, width, callback_steps): def check_inputs(self, image, height, width, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): if (
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment