Unverified Commit aa2ce41b authored by NotNANtoN's avatar NotNANtoN Committed by GitHub
Browse files

Fix img2img speed with LMS-Discrete Scheduler (#896)



Casting `self.sigmas` to a different dtype (that of `original_samples`) is not advisable. In my img2img pipeline this leads to a long running time in the `integrate.quad` call later on — by "long" I mean more than 10x slower.
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent 81fa2d68
...@@ -243,19 +243,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -243,19 +243,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps: torch.FloatTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64 # mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else: else:
self.timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = self.sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment