Unverified Commit 0343d8f5 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Do not use torch.float64 on the mps device (#942)

* Add failing test for #940.

* Do not use torch.float64 in mps.

* style

* Temporarily skip add_noise for IPNDMScheduler.

Until #990 is addressed.
parent 4b9f5895
......@@ -252,7 +252,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
self.timesteps = self.timesteps.to(original_samples.device)
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
......
......@@ -27,6 +27,7 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVeScheduler,
)
from diffusers.utils import torch_device
torch.backends.cuda.matmul.allow_tf32 = False
......@@ -258,6 +259,23 @@ class SchedulerCommonTest(unittest.TestCase):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
def test_add_noise_device(self):
for scheduler_class in self.scheduler_classes:
if scheduler_class == IPNDMScheduler:
# Skip until #990 is addressed
continue
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample.to(torch_device)
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device)
t = torch.tensor([10]).to(torch_device)
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment