".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5cd45c24bf616f09c818455184f3d1c3a3cebe00"
Unverified Commit 5a8b3569 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[DDIMScheduler] fix noise device in ddim step (#1189)

* fix noise device in ddim sched

* fix typo

* self.device -> device

* remove duplicated if

* use str device

* don't use str for device
parent 20a05d6a
...@@ -288,7 +288,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -288,7 +288,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if eta > 0: if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu" device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if variance_noise is not None and generator is not None: if variance_noise is not None and generator is not None:
raise ValueError( raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or" "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
...@@ -296,9 +296,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -296,9 +296,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
) )
if variance_noise is None: if variance_noise is None:
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to( if device.type == "mps":
device # randn does not work reproducibly on mps
) variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment