Unverified Commit 5a8b3569 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[DDIMScheduler] fix noise device in ddim step (#1189)

* fix noise device in ddim sched

* fix typo

* self.device -> device

* remove duplicated if

* use str device

* don't use str for device
parent 20a05d6a
......@@ -288,7 +288,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if variance_noise is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
......@@ -296,9 +296,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
if variance_noise is None:
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
device
)
if device.type == "mps":
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment