Unverified Commit e86a280c authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Remove warning about half precision on MPS (#1163)

Remove warning about half precision on MPS.
parent b4a1ed85
...@@ -209,13 +209,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -209,13 +209,13 @@ class DiffusionPipeline(ConfigMixin):
for name in module_names.keys(): for name in module_names.keys():
module = getattr(self, name) module = getattr(self, name)
if isinstance(module, torch.nn.Module): if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]: if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
logger.warning( logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It" "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make" " is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for" " sure to use an accelerator to run the pipeline in inference, due to the lack of"
" `float16` operations on those devices in PyTorch. Please remove the" " support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference." " `torch_dtype=torch.float16` argument, or use another device for inference."
) )
module.to(torch_device) module.to(torch_device)
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment