Unverified Commit 4393f7df authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove broken MPS build (#8472)

parent b6770a7e
......@@ -5,6 +5,7 @@ import os
import shutil
import subprocess
import sys
import warnings
import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version
......@@ -138,7 +139,6 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
)
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
......@@ -204,8 +204,15 @@ def get_extensions():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps
# FIXME: MPS build breaks custom ops registration, so it was disabled.
# See https://github.com/pytorch/vision/issues/8456.
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V
if force_mps:
warnings.warn("MPS build is temporarily disabled!!!!")
# elif torch.backends.mps.is_available() or force_mps:
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
# sources += source_mps
if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment