"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b4e6dc3037e75e2dc16466b914b1be597b06be9a"
Unverified Commit 5164c9fa authored by Santiago Víquez's avatar Santiago Víquez Committed by GitHub
Browse files

[Type hint] Score SDE VE pipeline (#325)

parent 93debd30
#!/usr/bin/env python3 #!/usr/bin/env python3
import warnings import warnings
from typing import Optional
import torch import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
class ScoreSdeVePipeline(DiffusionPipeline): class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
unet: UNet2DModel
scheduler: ScoreSdeVeScheduler
def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
super().__init__() super().__init__()
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs): def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 2000,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
if "torch_device" in kwargs: if "torch_device" in kwargs:
device = kwargs.pop("torch_device") device = kwargs.pop("torch_device")
warnings.warn( warnings.warn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment