run_ddpm.py 524 Bytes
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
#!/usr/bin/env python3
import torch

Patrick von Platen's avatar
Patrick von Platen committed
4
from diffusers import DDPMScheduler, UNetModel
Patrick von Platen's avatar
Patrick von Platen committed
5
6
7
8


# Network instance that is handed to DDPMScheduler further down.
# NOTE(review): dim / dim_mults presumably set the base channel width and the
# per-resolution multipliers -- confirm against the UNetModel signature.
model = UNetModel(
    dim=64,
    dim_mults=(1, 2, 4, 8),
)

Patrick von Platen's avatar
Patrick von Platen committed
9
# Wrap the model in a DDPMScheduler. The resulting object is called on a batch
# to obtain the training loss and exposes `.sample()` for generating images.
diffusion = DDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,  # number of steps
    loss_type="l1",  # L1 or L2
)
Patrick von Platen's avatar
Patrick von Platen committed
10
11
12
13
14
15
16
17

# Random batch of shape (8, 3, 128, 128) used as stand-in training data.
# Real images need to be normalized to the range [-1, +1].
training_images = torch.randn(8, 3, 128, 128)
# Calling the scheduler on a batch yields the training loss.
loss = diffusion(training_images)
# Backpropagate the loss; note that this snippet performs no optimizer step.
loss.backward()
# after a lot of training

# Ask the scheduler for a batch of 4 samples.
sampled_images = diffusion.sample(batch_size=4)
# Bare expression: displays the shape in an interactive session but has no
# effect when the file is run as a script. Expected value: (4, 3, 128, 128).
sampled_images.shape