example.py 952 Bytes
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
#!/usr/bin/env python3
Patrick von Platen's avatar
up  
Patrick von Platen committed
2
3
import os
import pathlib
Patrick von Platen's avatar
Patrick von Platen committed
4

Patrick von Platen's avatar
up  
Patrick von Platen committed
5
import numpy as np
Patrick von Platen's avatar
improve  
Patrick von Platen committed
6

Patrick von Platen's avatar
Patrick von Platen committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import PIL.Image
from modeling_ddpm import DDPM


# Names of the DDPM checkpoints to sample from, in processing order.
# Each entry is a bare repository name; the sampling loop loads it as
# "fusing/<name>". Entries ending in "-ema" are presumably the
# exponential-moving-average weights of the matching base checkpoint.
model_ids = """
    ddpm-lsun-cat
    ddpm-lsun-cat-ema
    ddpm-lsun-church-ema
    ddpm-lsun-church
    ddpm-lsun-bedroom
    ddpm-lsun-bedroom-ema
    ddpm-cifar10-ema
    ddpm-cifar10
    ddpm-celeba-hq
    ddpm-celeba-hq-ema
""".split()
Patrick von Platen's avatar
up  
Patrick von Platen committed
23

Patrick von Platen's avatar
up  
Patrick von Platen committed
24
25
26
27
28
29
30
31
32
33
# Sample a batch of images from every checkpoint in `model_ids` and save each
# sample as <output root>/<model_id>/image_<i>.png.
#
# The output root defaults to the original hard-coded location but can be
# redirected with the DDPM_OUTPUT_DIR environment variable, so the script is
# not tied to a single machine.
output_root = pathlib.Path(
    os.environ.get("DDPM_OUTPUT_DIR", "/home/patrick/images/hf")
)

for model_id in model_ids:
    # One sub-directory per checkpoint, created on demand.
    out_dir = output_root / model_id
    out_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoints are loaded from the "fusing/" namespace; draw 4 samples.
    ddpm = DDPM.from_pretrained("fusing/" + model_id)
    image = ddpm(batch_size=4)

    # Move axis 1 to the end (presumably NCHW -> NHWC, the layout PIL expects)
    # and map the assumed [-1, 1] sample range linearly onto [0, 255].
    # detach() lets .numpy() succeed even if the tensor tracks gradients.
    image_processed = image.detach().cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    # Clip before casting: NumPy's float -> uint8 conversion does not saturate
    # out-of-range values (the result is implementation-defined, typically a
    # wrap-around), so any sample slightly outside [-1, 1] would corrupt pixels.
    image_processed = np.clip(image_processed.numpy(), 0, 255).astype(np.uint8)

    # Iterating the array walks its first (batch) axis, one image per step.
    for i, sample in enumerate(image_processed):
        PIL.Image.fromarray(sample).save(out_dir / f"image_{i}.png")