Unverified Commit 84d7faeb authored by Adrian Punga's avatar Adrian Punga Committed by GitHub
Browse files

Fix support for MPS in KDPM2AncestralDiscreteScheduler (#6365)

Fix support for MPS

MPS doesn't support float64
parent 4c483deb
...@@ -277,7 +277,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -277,7 +277,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
if str(device).startswith("mps"):
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device) timesteps = torch.from_numpy(timesteps).to(device)
sigmas_interpol = sigmas_interpol.cpu() sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu() log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array( timesteps_interpol = np.array(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment