"vscode:/vscode.git/clone" did not exist on "619b9658e286ed10560a13f80084e286a6d85956"
Unverified Commit 727434c2 authored by Partho's avatar Partho Committed by GitHub
Browse files

Accept latents as optional input in Latent Diffusion pipeline (#1723)

* Latent Diffusion pipeline accept latents

* make style

* check for mps

randn does not work reproducibly on mps
parent 21e61eb3
...@@ -72,6 +72,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -72,6 +72,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
guidance_scale: Optional[float] = 1.0, guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
**kwargs, **kwargs,
...@@ -96,6 +97,10 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -96,6 +97,10 @@ class LDMTextToImagePipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*): generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic. deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -130,10 +135,21 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -130,10 +135,21 @@ class LDMTextToImagePipeline(DiffusionPipeline):
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
else:
latents = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, height // 8, width // 8), latents_shape,
generator=generator, generator=generator,
device=self.device,
) )
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device) latents = latents.to(self.device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment