Unverified Commit 22c1ba56 authored by psychedelicious's avatar psychedelicious Committed by GitHub
Browse files

Fix k_dpm_2 & k_dpm_2_a on MPS (#2241)

Needed to convert `timesteps` to `float32` a bit sooner.

Fixes #1537
parent 7386e773
......@@ -161,16 +161,16 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = timesteps
timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
self.sample = None
......
......@@ -149,18 +149,17 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps).to(device)
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(torch.float32)
else:
self.timesteps = timesteps
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
self.sample = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment