Commit fc67917a authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent 7ca832ca
......@@ -269,20 +269,21 @@ with torch.no_grad():
for i in range(sde.N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
# x, x_mean = corrector_update_fn(x, vec_t, model=model)
# x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x, x_mean = predictor_update_fn(x, vec_t, model=model)
# x, x_mean = new_corrector.update_fn(x, vec_t)
# x, x_mean = new_predictor.update_fn(x, vec_t)
x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
save_image(x)
# for 5
#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong"
#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong"
#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
# for 1000
assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong"
assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong"
assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
save_image(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment