"src/vscode:/vscode.git/clone" did not exist on "20e92586c1fda968ea3343ba0f44f2b21f3c09d2"
Unverified commit e30e1b89, authored by Anton Lozhkov and committed by GitHub

Support one-string prompts and custom image size in LDM (#212)

* Support one-string prompts in LDM

* Add other features from SD too
parent df90f0ce
 import inspect
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -24,20 +24,30 @@ class LDMTextToImagePipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-        output_type="pil",
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 256,
+        width: Optional[int] = 256,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 1.0,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        output_type: Optional[str] = "pil",
     ):
         # eta corresponds to η in paper and should be between [0, 1]

         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        batch_size = len(prompt)
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         self.unet.to(torch_device)
         self.vqvae.to(torch_device)

@@ -53,7 +63,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]

         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
         )
         latents = latents.to(torch_device)
...
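With this change, __call__ accepts either a single prompt string or a list of prompts, and the output resolution can be chosen freely as long as height and width are divisible by 8, since the latent grid is height // 8 by width // 8. Below is a minimal usage sketch; the LDMTextToImagePipeline.from_pretrained call and the "CompVis/ldm-text2im-large-256" checkpoint name are assumptions for illustration and are not part of this diff.

import torch
from diffusers import LDMTextToImagePipeline

# Assumption: load the pipeline from a pretrained LDM text-to-image checkpoint.
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# A bare string is now accepted; batch_size is inferred as 1.
generator = torch.manual_seed(0)
images = pipe(
    "A painting of a squirrel eating a burger",
    height=384,              # must be divisible by 8
    width=256,               # must be divisible by 8
    num_inference_steps=50,
    guidance_scale=1.0,
    eta=0.0,
    generator=generator,
)["sample"]                  # default output_type="pil" returns PIL images

images[0].save("squirrel.png")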
@@ -854,7 +854,7 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
         image_slice = image[0, -3:, -3:, -1]
...
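Since batch_size is now inferred from the prompt, passing a list still behaves as before. A short sketch of the batched call, assuming the pipe object from the previous example:

# Batched call: batch_size is inferred as len(prompt), and the latents
# tensor gets one row per prompt.
prompts = ["A painting of a squirrel eating a burger"] * 2
images = pipe(prompts, generator=torch.manual_seed(0), num_inference_steps=50)["sample"]
assert len(images) == 2  # one image per prompt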