Unverified Commit 5f25818a authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

allow custom height, width in StableDiffusionPipeline (#179)

* allow custom height width

* raise if height width are not mul of 8
parent c25d8c90
...@@ -28,6 +28,8 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -28,6 +28,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0, guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
...@@ -45,6 +47,9 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -45,6 +47,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
self.unet.to(torch_device) self.unet.to(torch_device)
self.vae.to(torch_device) self.vae.to(torch_device)
self.text_encoder.to(torch_device) self.text_encoder.to(torch_device)
...@@ -72,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -72,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
# get the intial random noise # get the intial random noise
latents = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator, generator=generator,
) )
latents = latents.to(torch_device) latents = latents.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment