Unverified Commit 6e099e2c authored by Tanishq Abraham's avatar Tanishq Abraham Committed by GitHub
Browse files

add num_inference_steps arg to DDPM (#935)

parent 82044153
......@@ -42,6 +42,7 @@ class DDPMPipeline(DiffusionPipeline):
self,
batch_size: int = 1,
generator: Optional[torch.Generator] = None,
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
......@@ -53,6 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -73,7 +77,7 @@ class DDPMPipeline(DiffusionPipeline):
image = image.to(self.device)
# set step values
self.scheduler.set_timesteps(1000)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment