"...text-generation-inference.git" did not exist on "3011639ff7a6db7e6aaa5506ff516b9df8bc443e"
Unverified Commit cb0bf0bd authored by Kevin Turner's avatar Kevin Turner Committed by GitHub
Browse files

fix(DDIM scheduler): use correct dtype for noise (#742)

Otherwise, it crashes when eta > 0 with float16.
parent e0fece2b
...@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0: if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu" device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, generator=generator).to(device) noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment