Commit 1dc856e5 authored by William Berman, committed by Will Berman

ddpm scheduler variance fixes

parent 2cbdc586
@@ -214,16 +214,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
+        variance = torch.clamp(variance, min=1e-20)
 
         if variance_type is None:
             variance_type = self.config.variance_type
 
         # hacks - were probably added for training stability
         if variance_type == "fixed_small":
-            variance = torch.clamp(variance, min=1e-20)
+            variance = variance
         # for rl-diffuser https://arxiv.org/abs/2205.09991
         elif variance_type == "fixed_small_log":
-            variance = torch.log(torch.clamp(variance, min=1e-20))
+            variance = torch.log(variance)
             variance = torch.exp(0.5 * variance)
         elif variance_type == "fixed_large":
             variance = current_beta_t
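
This hunk moves the 1e-20 clamp ahead of the branch on variance_type, so every log-based path sees a strictly positive value. A minimal standalone sketch of the computation (illustrative names and schedule, not the diffusers API): the posterior variance beta_tilde_t = (1 - alphabar_{t-1}) / (1 - alphabar_t) * beta_t collapses to exactly zero at t == 0, because alphabar_{t-1} is defined as 1 there, and torch.log of it would return -inf without the clamp.

import torch

# A minimal sketch, not the diffusers API: posterior variance of the DDPM
# reverse process, beta_tilde_t = (1 - alphabar_{t-1}) / (1 - alphabar_t) * beta_t.
betas = torch.linspace(1e-4, 0.02, 1000)            # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alphabar_t

def posterior_variance(t: int) -> torch.Tensor:
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
    current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
    # At t == 0, alpha_prod_t_prev == 1 makes the variance exactly 0, so the
    # clamp is what keeps torch.log finite in the "fixed_small_log" branch.
    return torch.clamp(variance, min=1e-20)

print(posterior_variance(0))    # tensor(1.0000e-20), not 0
print(posterior_variance(500))  # strictly positive, close to beta_t
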
@@ -234,7 +235,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             return predicted_variance
         elif variance_type == "learned_range":
             min_log = torch.log(variance)
-            max_log = torch.log(self.betas[t])
+            max_log = torch.log(current_beta_t)
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
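
The learned_range branch is the interpolated variance from Improved DDPM (Nichol & Dhariwal, 2021), Sigma_theta = exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t)), where v is an extra model output in [-1, 1]. The fix swaps self.betas[t] for current_beta_t, presumably so the upper bound stays consistent with the beta recomputed from alphas_cumprod earlier in this function (the two differ when the scheduler steps over a trimmed timestep sequence). A hedged sketch of that interpolation; the helper name and signature are illustrative:

import torch

# Illustrative helper, not the diffusers API: interpolate between the two
# fixed log-variance bounds using the model's extra output channel.
def learned_range_variance(
    predicted_variance: torch.Tensor,  # model output v, expected in [-1, 1]
    posterior_variance: torch.Tensor,  # beta_tilde_t, the "fixed_small" value
    current_beta_t: torch.Tensor,      # beta_t recomputed from alphas_cumprod
) -> torch.Tensor:
    min_log = torch.log(posterior_variance)  # lower bound in log space
    max_log = torch.log(current_beta_t)      # upper bound in log space
    frac = (predicted_variance + 1) / 2      # map [-1, 1] -> [0, 1]
    return torch.exp(frac * max_log + (1 - frac) * min_log)

Note that the hunk above returns the interpolated log-variance (frac * max_log + (1 - frac) * min_log) without exponentiating, leaving that to the caller; the exp here only makes the sketch return a variance directly.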