Unverified Commit 7b628a22 authored by Partho's avatar Partho Committed by GitHub
Browse files

[Type hint] PNDM pipeline (#327)

* [Type hint] PNDM pipeline

* ran make style

* Revert "ran make style" wrong black version
parent 033b77eb
......@@ -15,20 +15,33 @@
import warnings
from typing import Optional
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import PNDMScheduler
class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
unet: UNet2DModel
scheduler: PNDMScheduler
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment