Unverified Commit 1d04e1b4 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Continuation of #942: additional float64 failure (#996)

* Add failing test for #940.

* Do not use torch.float64 in mps.

* style

* Temporarily skip add_noise for IPNDMScheduler.

Until #990 is addressed.

* Fix additional float64 error in mps.

* Improve add_noise test

* Slight edit – I think it's clearer this way.
parent a23ad87d
......@@ -252,9 +252,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
timesteps = timesteps.to(original_samples.device)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
......
......@@ -266,13 +266,14 @@ class SchedulerCommonTest(unittest.TestCase):
continue
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(100)
sample = self.dummy_sample.to(torch_device)
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device)
t = torch.tensor([10]).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment