Commit 78e99a99 authored by Patrick von Platen

adapt run.py

parent fc67917a
--- a/run.py
+++ b/run.py
@@ -11,6 +11,8 @@ from models import ddpm as ddpm_model
 from models import layerspp
 from models import layers
 from models import normalization
+from models.ema import ExponentialMovingAverage
+from losses import get_optimizer
 from utils import restore_checkpoint
@@ -27,6 +29,7 @@ import datasets
 import torch
+torch.backends.cuda.matmul.allow_tf32 = False
 torch.manual_seed(0)
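The added TF32 switch pins down run-to-run variation: with TF32 enabled, Ampere-class GPUs execute fp32 matmuls at reduced mantissa precision, which would break the checksum-style comparisons at the bottom of this script. A minimal sketch of a stricter determinism setup, using standard PyTorch backend flags (the cudnn switches are not part of this commit):

import torch

torch.backends.cuda.matmul.allow_tf32 = False  # full-precision fp32 matmuls
torch.backends.cudnn.allow_tf32 = False        # same for cudnn convolutions
torch.backends.cudnn.deterministic = True      # deterministic conv algorithms
torch.backends.cudnn.benchmark = False         # no autotuning between runs
torch.manual_seed(0)                           # fixed RNG for prior sample and noise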
@@ -81,7 +84,6 @@ torch.manual_seed(0)
 class NewReverseDiffusionPredictor:
   def __init__(self, sde, score_fn, probability_flow=False):
     super().__init__()
     self.sde = sde
@@ -112,7 +114,6 @@ class NewReverseDiffusionPredictor:
 class NewLangevinCorrector:
   def __init__(self, sde, score_fn, snr, n_steps):
     super().__init__()
     self.sde = sde
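NewReverseDiffusionPredictor and NewLangevinCorrector mirror score_sde's predictor and corrector classes without the registry plumbing. For orientation, one corrector step of predictor-corrector (PC) sampling looks roughly like the following; this is a schematic restatement of the Langevin update for the VE SDE (where alpha = 1), not code taken from this commit:

# One Langevin corrector step: climb the score, then re-add noise sized
# so the update's signal-to-noise ratio stays near the target `snr`.
grad = score_fn(x, vec_t)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (snr * noise_norm / grad_norm) ** 2 * 2  # alpha = 1 for VESDE
x_mean = x + step_size * grad
x = x_mean + torch.sqrt(step_size * 2) * noise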
@@ -146,28 +147,19 @@ class NewLangevinCorrector:
 def save_image(x):
-  # image_processed = x.cpu().permute(0, 2, 3, 1)
-  # image_processed = (image_processed + 1.0) * 127.5
-  # image_processed = image_processed.numpy().astype(np.uint8)
   image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
   image_pil = PIL.Image.fromarray(image_processed[0])
-  # 6. save image
   image_pil.save("../images/hey.png")
-#x = np.load("cifar10.npy")
-#
-#save_image(x)
-# @title Load the score-based model
 sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
 if sde.lower() == 'vesde':
-  from configs.ve import cifar10_ncsnpp_continuous as configs
+  # from configs.ve import cifar10_ncsnpp_continuous as configs
-  ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
+  # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
-  # from configs.ve import ffhq_ncsnpp_continuous as configs
+  from configs.ve import ffhq_ncsnpp_continuous as configs
-  # ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
+  ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
   config = configs.get_config()
-  config.model.num_scales = 1000
+  config.model.num_scales = 2
   sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
   sampling_eps = 1e-5
 elif sde.lower() == 'vpsde':
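Dropping num_scales from 1000 to 2 turns the sampler below into a two-step smoke test: sde.N picks up this value, so the PC loop runs exactly twice, enough to compare numbers against the reference sums at the end of the script but far too few steps to produce a real sample. Assuming sde.T == 1 as in score_sde's VESDE, the visited noise levels are just the two endpoints:

timesteps = torch.linspace(sde.T, sampling_eps, sde.N)  # tensor([1.0000e+00, 1.0000e-05])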
@@ -189,32 +181,53 @@ config.eval.batch_size = batch_size
 random_seed = 0 #@param {"type": "integer"}
-score_model = mutils.create_model(config)
+#sigmas = mutils.get_sigmas(config)
+#scaler = datasets.get_data_scaler(config)
+#inverse_scaler = datasets.get_data_inverse_scaler(config)
+#score_model = mutils.create_model(config)
+#
+#optimizer = get_optimizer(config, score_model.parameters())
+#ema = ExponentialMovingAverage(score_model.parameters(),
+#                               decay=config.model.ema_rate)
+#state = dict(step=0, optimizer=optimizer,
+#             model=score_model, ema=ema)
+#
+#state = restore_checkpoint(ckpt_filename, state, config.device)
+#ema.copy_to(score_model.parameters())
+#score_model = mutils.create_model(config)
+from diffusers import NCSNpp
+score_model = NCSNpp(config).to(config.device)
+score_model = torch.nn.DataParallel(score_model)
-loaded_state = torch.load(ckpt_filename)
+loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt")
-score_model.load_state_dict(loaded_state["model"], strict=False)
+del loaded_state["module.sigmas"]
+score_model.load_state_dict(loaded_state, strict=False)
 inverse_scaler = datasets.get_data_inverse_scaler(config)
 predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
 corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
+def image_grid(x):
+  size = config.data.image_size
+  channels = config.data.num_channels
+  img = x.reshape(-1, size, size, channels)
+  w = int(np.sqrt(img.shape[0]))
+  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
+  return img
 #@title PC sampling
 img_size = config.data.image_size
 channels = config.data.num_channels
 shape = (batch_size, channels, img_size, img_size)
 probability_flow = False
-snr = 0.16 #@param {"type": "number"}
+snr = 0.15 #@param {"type": "number"}
 n_steps = 1#@param {"type": "integer"}
+#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
+#                                      inverse_scaler, snr, n_steps=n_steps,
+#                                      probability_flow=probability_flow,
+#                                      continuous=config.training.continuous,
+#                                      eps=sampling_eps, device=config.device)
+#
+#x, n = sampling_fn(score_model)
+#save_image(x)
 def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
   """A wrapper that configures and returns the update function of predictors."""
   score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
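Loading the original score_sde checkpoint into the diffusers NCSNpp needs strict=False because the two module trees do not match exactly; module.sigmas is deleted first since it is a precomputed buffer rather than a learned weight. A sketch of making that tolerance visible rather than silent; the key lists come from PyTorch's load_state_dict return value, and the assert is an illustrative guard, not part of this commit:

result = score_model.load_state_dict(loaded_state, strict=False)
print("missing keys:", result.missing_keys)        # expected by the model, absent from the file
print("unexpected keys:", result.unexpected_keys)  # present in the file, unused by the model
assert not result.missing_keys, "checkpoint left parameters uninitialized"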
@@ -253,14 +266,14 @@ corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                         snr=snr,
                                         n_steps=n_steps)
-device = "cuda"
+device = config.device
-model = score_model.to(device)
+model = score_model
-denoise = False
+denoise = True
 new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps)
 new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model)
+#
 with torch.no_grad():
   # Initial sample
   x = sde.prior_sampling(shape).to(device)
@@ -269,21 +282,32 @@ with torch.no_grad():
   for i in range(sde.N):
     t = timesteps[i]
     vec_t = torch.ones(shape[0], device=t.device) * t
-    x, x_mean = corrector_update_fn(x, vec_t, model=model)
+    # x, x_mean = corrector_update_fn(x, vec_t, model=model)
-    x, x_mean = predictor_update_fn(x, vec_t, model=model)
+    # x, x_mean = predictor_update_fn(x, vec_t, model=model)
-    # x, x_mean = new_corrector.update_fn(x, vec_t)
+    x, x_mean = new_corrector.update_fn(x, vec_t)
-    # x, x_mean = new_predictor.update_fn(x, vec_t)
+    x, x_mean = new_predictor.update_fn(x, vec_t)
   x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
-save_image(x)
+#save_image(x)
-# for 5
+# for 5 cifar10
-#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
-#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
+x_sum = 106071.9922
+x_mean = 34.52864456176758
+# for 1000 cifar10
+x_sum = 461.9700
+x_mean = 0.1504
+# for 2 ffhq_1024
+x_sum = 3382810112.0
+x_mean = 1075.366455078125
+def check_x_sum_x_mean(x, x_sum, x_mean):
+  assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
+  assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
-# for 1000
-assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
-assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
+check_x_sum_x_mean(x, x_sum, x_mean)
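The new check compares the absolute sum and absolute mean of the sampled batch against hard-coded references; since x_sum and x_mean are reassigned three times, only the last pair (the two-step ffhq 1024 run) is actually tested. One caveat: at x_sum around 3.4e9 an absolute tolerance of 1e-2 sits below fp32 resolution, so that comparison only passes on an exact match. A relative-tolerance variant is a more scale-robust sketch; check_close and its rtol are illustrative names, not from the commit:

def check_close(actual, expected, rtol=1e-5):
  # scale-aware comparison: tolerance grows with the magnitude of the reference
  assert abs(actual - expected) <= rtol * abs(expected), f"off: {actual} vs {expected}"

check_close(x.abs().sum().item(), x_sum)
check_close(x.abs().mean().item(), x_mean)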