"src/vscode:/vscode.git/clone" did not exist on "d5dd8df3b4e978e3f0549fa35e245a8785116af7"
Commit d5c527a4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

clean up

parent 135acd83
#!/usr/bin/env python3 #!/usr/bin/env python3
import numpy as np
import torch import torch
import PIL
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline): class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler): def __init__(self, model, scheduler):
super().__init__() super().__init__()
...@@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.model.to(device) model = self.model.to(device)
centered = False # TODO(Patrick) move to scheduler config
n_steps = 1 n_steps = 1
x = torch.randn(*shape) * self.scheduler.config.sigma_max x = torch.randn(*shape) * self.scheduler.config.sigma_max
...@@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline):
x, x_mean = self.scheduler.step_pred(result, x, t) x, x_mean = self.scheduler.step_pred(result, x, t)
x = x_mean return x_mean
if centered:
x = (x + 1.0) / 2.0
return x
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
# x_sum = 3382810112.0
# x_mean = 1075.366455078125
#
#
# def check_x_sum_x_mean(x, x_sum, x_mean):
# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
# check_x_sum_x_mean(x, x_sum, x_mean)
#
#
# def save_image(x):
# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("../images/hey.png")
#
#
# save_image(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment