Commit ef4365c6 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent addc43af
#!/usr/bin/env python3 #!/usr/bin/env python3
import tempfile import tempfile
import sys import sys
import os
import pathlib
from modeling_ddpm import DDPM from modeling_ddpm import DDPM
import PIL.Image
import numpy as np
model_id = sys.argv[1] model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-lsun-cifar10", "ddpm-lsun-celeba-hq", "ddpm-lsun-celeba-hq-ema"]
ddpm = DDPM.from_pretrained(model_id) for model_id in model_ids:
image = ddpm()
import PIL.Image path = os.path.join("/home/patrick/images/hf", model_id)
import numpy as np pathlib.Path(path).mkdir(parents=True, exist_ok=True)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5 ddpm = DDPM.from_pretrained("fusing/" + model_id)
image_processed = image_processed.numpy().astype(np.uint8) image = ddpm(batch_size=4)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png") image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
import ipdb; ipdb.set_trace() for i in range(image_processed.shape[0]):
image_pil = PIL.Image.fromarray(image_processed[i])
image_pil.save(os.path.join(path, f"image_{i}.png"))
...@@ -33,7 +33,7 @@ class DDPM(DiffusionPipeline): ...@@ -33,7 +33,7 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
# 1. Sample gaussian noise # 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t # i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
......
...@@ -108,7 +108,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin): ...@@ -108,7 +108,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
def sample_variance(self, time_step, shape, device, generator=None): def sample_variance(self, time_step, shape, device, generator=None):
variance = self.log_variance[time_step] variance = self.log_variance[time_step]
nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1) nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
noise = self.sample_noise(shape, device=device, generator=generator) noise = self.sample_noise(shape, device=device, generator=generator)
......
...@@ -76,7 +76,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): ...@@ -76,7 +76,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
class ModelTesterMixin(unittest.TestCase): class ModelTesterMixin(unittest.TestCase):
@property @property
def dummy_input(self): def dummy_input(self):
batch_size = 1 batch_size = 4
num_channels = 3 num_channels = 3
sizes = (32, 32) sizes = (32, 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment