Unverified Commit 5dda1735 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Inference support for `mps` device (#355)

* Initial support for mps in Stable Diffusion pipeline.

* Initial "warmup" implementation when using mps.

* Make some deterministic tests pass with mps.

* Disable training tests when using mps.

* SD: generate latents in CPU then move to device.

This is especially important when using the mps device, because
generators are not supported there. See for example
https://github.com/pytorch/pytorch/issues/84288.

In addition, the other pipelines seem to use the same approach: generate
the random samples then move to the appropriate device.

After this change, generating an image in MPS produces the same result
as when using the CPU, if the same seed is used.

* Remove prints.

* Pass AutoencoderKL test_output_pretrained with mps.

Sampling from `posterior` must be done in CPU.

* Style

* Do not use torch.long for log op in mps device.

* Perform incompatible padding ops in CPU.

UNet tests now pass.
See https://github.com/pytorch/pytorch/issues/84535



* Style: fix import order.

* Remove unused symbols.

* Remove MPSWarmupMixin, do not apply automatically.

We do apply warmup in the tests, but not during normal use.
This adopts some PR suggestions by @patrickvonplaten.

* Add comment for mps fallback to CPU step.

* Add README_mps.md for mps installation and use.

* Apply `black` to modified files.

* Restrict README_mps to SD, show measures in table.

* Make PNDM indexing compatible with mps.

Addresses #239.

* Do not use float64 when using LDMScheduler.

Fixes #358.

* Fix typo identified by @patil-suraj
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Adapt example to new output style.

* Restore 1:1 results reproducibility with CompVis.

However, mps latents need to be generated in CPU because generators
don't work in the mps device.

* Move PyTorch nightly to requirements.

* Adapt `test_scheduler_outputs_equivalence` ton MPS.

* mps: skip training tests instead of ignoring silently.

* Make VQModel tests pass on mps.

* mps ddim tests: warmup, increase tolerance.

* ScoreSdeVeScheduler indexing made mps compatible.

* Make ldm pipeline tests pass using warmup.

* Style

* Simplify casting as suggested in PR.

* Add Known Issues to readme.

* `isort` import order.

* Remove _mps_warmup helpers from ModelMixin.

And just make changes to the tests.

* Skip tests using unittest decorator for consistency.

* Remove temporary var.

* Remove spurious blank space.

* Remove unused symbol.

* Remove README_mps.
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> 
parent 98f34683
...@@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4 ...@@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4
conda install -c conda-forge diffusers conda install -c conda-forge diffusers
``` ```
**Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
## Contributing ## Contributing
We ❤️ contributions from the open-source community! We ❤️ contributions from the open-source community!
......
...@@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module): ...@@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module):
self.attn2._slice_size = slice_size self.attn2._slice_size = slice_size
def forward(self, x, context=None): def forward(self, x, context=None):
x = x.contiguous() if x.device.type == "mps" else x
x = self.attn1(self.norm1(x)) + x x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x x = self.ff(self.norm3(x)) + x
......
...@@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): ...@@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor) out = input.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out[ out = out[
:, :,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
......
...@@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps.to(dtype=torch.float32)
timesteps = timesteps[None].to(device=sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0]) timesteps = timesteps.expand(sample.shape[0])
......
...@@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object): ...@@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object):
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device) device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
x = self.mean + self.std * sample
return x return x
def kl(self, other=None): def kl(self, other=None):
......
...@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput): ...@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput):
class DiffusionPipeline(ConfigMixin): class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json" config_name = "model_index.json"
def register_modules(self, **kwargs): def register_modules(self, **kwargs):
......
...@@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None: if latents is None:
latents = torch.randn( latents = torch.randn(
latents_shape, latents_shape,
generator=generator, generator=generator,
device=self.device, device=latents_device,
) )
else: else:
if latents.shape != latents_shape: if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device) latents = latents.to(self.device)
# set timesteps # set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
......
...@@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray], noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor: ) -> torch.Tensor:
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
......
...@@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt": elif tensor_format == "pt":
return torch.where( return torch.where(
timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device) timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
) )
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
...@@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
) # torch.repeat_interleave(timestep, sample.shape[0]) ) # torch.repeat_interleave(timestep, sample.shape[0])
timesteps = (timestep * (len(self.timesteps) - 1)).long() timesteps = (timestep * (len(self.timesteps) - 1)).long()
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.discrete_sigmas.device)
sigma = self.discrete_sigmas[timesteps].to(sample.device) sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep) adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
drift = self.zeros_like(sample) drift = self.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
global_rng = random.Random() global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def parse_flag_from_env(key, default=False): def parse_flag_from_env(key, default=False):
......
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
import inspect import inspect
import tempfile import tempfile
import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import torch_device from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
...@@ -38,6 +40,11 @@ class ModelTesterMixin: ...@@ -38,6 +40,11 @@ class ModelTesterMixin:
new_model.to(torch_device) new_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict) image = model(**inputs_dict)
if isinstance(image, dict): if isinstance(image, dict):
image = image.sample image = image.sample
...@@ -55,7 +62,12 @@ class ModelTesterMixin: ...@@ -55,7 +62,12 @@ class ModelTesterMixin:
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
first = model(**inputs_dict) first = model(**inputs_dict)
if isinstance(first, dict): if isinstance(first, dict):
first = first.sample first = first.sample
...@@ -132,6 +144,7 @@ class ModelTesterMixin: ...@@ -132,6 +144,7 @@ class ModelTesterMixin:
self.assertEqual(output_1.shape, output_2.shape) self.assertEqual(output_1.shape, output_2.shape)
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_training(self): def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -147,6 +160,7 @@ class ModelTesterMixin: ...@@ -147,6 +160,7 @@ class ModelTesterMixin:
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
loss.backward() loss.backward()
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_ema_training(self): def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -167,8 +181,13 @@ class ModelTesterMixin: ...@@ -167,8 +181,13 @@ class ModelTesterMixin:
def test_scheduler_outputs_equivalence(self): def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t): def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0 t[t != t] = 0
return t return t.to(device)
def recursive_check(tuple_object, dict_object): def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)): if isinstance(tuple_object, (List, Tuple)):
...@@ -198,7 +217,12 @@ class ModelTesterMixin: ...@@ -198,7 +217,12 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs_dict = model(**inputs_dict) with torch.no_grad():
outputs_tuple = model(**inputs_dict, return_dict=False) # Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict) recursive_check(outputs_tuple, outputs_dict)
...@@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3 num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device) time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step} return {"sample": noise, "timestep": time_step}
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import torch import torch
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import floats_tensor, torch_device from diffusers.testing_utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin from .test_modeling_common import ModelTesterMixin
...@@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): ...@@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model = model.to(torch_device) model = model.to(torch_device)
model.eval() model.eval()
# One-time warmup pass (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
_ = model(image, sample_posterior=True).sample
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
......
...@@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device) image = image.to(torch_device)
with torch.no_grad(): with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = model(image)
output = model(image).sample output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu() output_slice = output[0, -1, -3:, -3:].flatten().cpu()
......
...@@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase): ...@@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase):
ddpm.to(torch_device) ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None) ddpm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
...@@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase): ...@@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase):
expected_slice = np.array( expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
) )
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
def test_pndm_cifar10(self): def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet unet = self.dummy_uncond_unet
...@@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase): ...@@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
"sample"
]
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[ image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
"sample" "sample"
...@@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase): ...@@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase):
ldm.to(torch_device) ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment