Unverified commit e30e1b89, authored by Anton Lozhkov, committed by GitHub
Browse files

Support one-string prompts and custom image size in LDM (#212)

* Support one-string prompts in LDM

* Add other features from SD too
parent df90f0ce
import inspect
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -24,20 +24,30 @@ class LDMTextToImagePipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
output_type="pil",
prompt: Union[str, List[str]],
height: Optional[int] = 256,
width: Optional[int] = 256,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
torch_device: Optional[Union[str, torch.device]] = None,
output_type: Optional[str] = "pil",
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = len(prompt)
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
self.unet.to(torch_device)
self.vqvae.to(torch_device)
......@@ -53,7 +63,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
(batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator,
)
latents = latents.to(torch_device)
......
......@@ -854,7 +854,7 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment