Commit addc43af authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct modeling_ddpm

parent f9a4532f
...@@ -27,8 +27,9 @@ class DDPM(DiffusionPipeline): ...@@ -27,8 +27,9 @@ class DDPM(DiffusionPipeline):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, generator=None, torch_device=None): def __call__(self, batch_size=1, generator=None, torch_device=None):
torch_device = "cuda" if torch.cuda.is_available() else "cpu" if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device) self.unet.to(torch_device)
# 1. Sample gaussian noise # 1. Sample gaussian noise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment