run_ddpm.py 540 Bytes
Newer Older
1
2
3
#!/usr/bin/env python3
import torch

Patrick von Platen's avatar
improve  
Patrick von Platen committed
4
from diffusers import GaussianDDPMScheduler, UNetModel
5
6


7
# Instantiate the UNetModel that is handed to the scheduler further down.
model = UNetModel(
    dim=64,
    dim_mults=(1, 2, 4, 8),
)
8

Patrick von Platen's avatar
improve  
Patrick von Platen committed
9
# Wrap the model in the DDPM scheduler. Later in this script the resulting
# object is called directly to obtain a loss and its `.sample()` is used to
# generate images.
diffusion = GaussianDDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,  # number of steps
    loss_type="l1",  # L1 or L2
)
10
11
12
13
14
15
16
17

# Single forward/backward pass on random stand-in data (no optimizer step is
# taken here). Real training images need to be normalized to the range -1 to +1.
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()

# After a lot of training, generate a batch of images.
sampled_images = diffusion.sample(batch_size=4)
# A bare `sampled_images.shape` expression is silently discarded when this
# file runs as a script, so print it to make the result visible.
print(tuple(sampled_images.shape))  # expected: (4, 3, 128, 128)