Unverified Commit 1d04e1b4 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Continuation of #942: additional float64 failure (#996)

* Add failing test for #940.

* Do not use torch.float64 in mps.

* style

* Temporarily skip add_noise for IPNDMScheduler.

Until #990 is addressed.

* Fix additional float64 error in mps.

* Improve add_noise test

* Slight edit – I think it's clearer this way.
parent a23ad87d
...@@ -252,8 +252,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -252,8 +252,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype) # mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps schedule_timesteps = self.timesteps
......
...@@ -266,13 +266,14 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -266,13 +266,14 @@ class SchedulerCommonTest(unittest.TestCase):
continue continue
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(100)
sample = self.dummy_sample.to(torch_device) sample = self.dummy_sample.to(torch_device)
scaled_sample = scheduler.scale_model_input(sample, 0.0) scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape) self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device) noise = torch.randn_like(scaled_sample).to(torch_device)
t = torch.tensor([10]).to(torch_device) t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t) noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape) self.assertEqual(noised.shape, scaled_sample.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment