Unverified Commit cb0bf0bd authored by Kevin Turner's avatar Kevin Turner Committed by GitHub
Browse files

fix(DDIM scheduler): use correct dtype for noise (#742)

Otherwise, it crashes when eta > 0 with float16.
parent e0fece2b
......@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, generator=generator).to(device)
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment