Unverified Commit 813744e5 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

MPS schedulers: don't use float64 (#1169)

* Schedulers: don't use float64 on mps

* Test set_timesteps() on device (float schedulers).

* SD pipeline: use device in set_timesteps.

* SD in-painting pipeline: use device in set_timesteps.

* Tests: fix mps crashes.

* Skip test_load_pipeline_from_git on mps.

Not compatible with float16.

* Use device.type instead of str in Euler schedulers.
parent 5a8b3569
...@@ -360,12 +360,9 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -360,12 +360,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device) latents = latents.to(self.device)
# set timesteps # set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
......
...@@ -416,12 +416,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -416,12 +416,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
# set timesteps # set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
......
...@@ -151,6 +151,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -151,6 +151,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step( def step(
...@@ -217,8 +221,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -217,8 +221,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt prev_sample = sample + derivative * dt
device = model_output.device if torch.is_tensor(model_output) else "cpu" device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if str(device) == "mps": if device.type == "mps":
# randn does not work reproducibly on mps # randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device device
......
...@@ -152,6 +152,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -152,6 +152,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step( def step(
...@@ -214,8 +218,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -214,8 +218,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device if torch.is_tensor(model_output) else "cpu" device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
if str(device) == "mps": if device.type == "mps":
# randn does not work reproducibly on mps # randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device device
......
...@@ -173,7 +173,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -173,7 +173,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.derivatives = [] self.derivatives = []
......
...@@ -456,6 +456,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase): ...@@ -456,6 +456,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on # fmt: on
] ]
) )
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4") model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed) latents = self.get_latents(seed)
...@@ -507,6 +508,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase): ...@@ -507,6 +508,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on # fmt: on
] ]
) )
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5") model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed) latents = self.get_latents(seed)
...@@ -558,6 +560,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase): ...@@ -558,6 +560,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on # fmt: on
] ]
) )
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting") model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64)) latents = self.get_latents(seed, shape=(4, 9, 64, 64))
......
...@@ -41,7 +41,7 @@ from diffusers import ( ...@@ -41,7 +41,7 @@ from diffusers import (
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized from parameterized import parameterized
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
...@@ -124,7 +124,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -124,7 +124,7 @@ class CustomPipelineTests(unittest.TestCase):
assert output_str == "This is a local test" assert output_str == "This is a local test"
@slow @slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") @require_torch_gpu
def test_load_pipeline_from_git(self): def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
......
...@@ -83,8 +83,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -83,8 +83,8 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", None)
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step) time_step = float(time_step)
...@@ -1192,6 +1192,31 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1192,6 +1192,31 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3 assert abs(result_mean.item() - 1.31) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3
class EulerDiscreteSchedulerTest(SchedulerCommonTest): class EulerDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (EulerDiscreteScheduler,) scheduler_classes = (EulerDiscreteScheduler,)
...@@ -1248,6 +1273,34 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1248,6 +1273,34 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3 assert abs(result_mean.item() - 0.0131) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator().manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (EulerAncestralDiscreteScheduler,) scheduler_classes = (EulerAncestralDiscreteScheduler,)
...@@ -1303,6 +1356,38 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1303,6 +1356,38 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 152.3192) < 1e-2 assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3 assert abs(result_mean.item() - 0.1983) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator().manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
if not str(torch_device).startswith("mps"):
# The following sum varies between 148 and 156 on mps. Why?
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
else:
# Larger tolerance on mps
assert abs(result_mean.item() - 0.1983) < 1e-2
class IPNDMSchedulerTest(SchedulerCommonTest): class IPNDMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (IPNDMScheduler,) scheduler_classes = (IPNDMScheduler,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment