"examples/community/stable_diffusion_ipex.py" did not exist on "0e82fb19e16bd2d45ade31c9a4b871de56e7e80a"
Unverified Commit 06bc1daf authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Type hint] Karras VE pipeline (#288)



* [Type hint] Karras VE pipeline

* Apply suggestions from code review
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent 7e1b202d
#!/usr/bin/env python3
import warnings
from typing import Optional
import torch
......@@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline):
unet: UNet2DModel
scheduler: KarrasVeScheduler
def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment