Unverified Commit 58933059 authored by Lukas Struppek's avatar Lukas Struppek Committed by GitHub
Browse files

VersatileDiffusion: fix input processing (#1568)



* fix versatile diffusion input

* merge main

* `make fix-copies`
Co-authored-by: default avataranton- <anton@huggingface.co>
parent 31444f57
......@@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
......
......@@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
......
......@@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
prompt = [p for p in prompt]
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
......@@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
def check_inputs(self, image, height, width, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment