Unverified Commit f96c42fc authored by Huy Do's avatar Huy Do Committed by GitHub
Browse files

Re-enable vision MPS builds (#8485)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent f1bcbd31
...@@ -54,7 +54,7 @@ jobs: ...@@ -54,7 +54,7 @@ jobs:
export GPU_ARCH_TYPE=cpu export GPU_ARCH_TYPE=cpu
export GPU_ARCH_VERSION='' export GPU_ARCH_VERSION=''
./.github/scripts/cmake.sh ${CONDA_RUN} ./.github/scripts/cmake.sh
windows: windows:
strategy: strategy:
......
...@@ -68,7 +68,7 @@ jobs: ...@@ -68,7 +68,7 @@ jobs:
export GPU_ARCH_TYPE=cpu export GPU_ARCH_TYPE=cpu
export GPU_ARCH_VERSION='' export GPU_ARCH_VERSION=''
./.github/scripts/unittest.sh ${CONDA_RUN} ./.github/scripts/unittest.sh
unittests-windows: unittests-windows:
strategy: strategy:
......
...@@ -5,7 +5,6 @@ import os ...@@ -5,7 +5,6 @@ import os
import shutil import shutil
import subprocess import subprocess
import sys import sys
import warnings
import torch import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version from pkg_resources import DistributionNotFound, get_distribution, parse_version
...@@ -139,6 +138,7 @@ def get_extensions(): ...@@ -139,6 +138,7 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
) )
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
print("Compiling extensions with following flags:") print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1" force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
...@@ -204,15 +204,8 @@ def get_extensions(): ...@@ -204,15 +204,8 @@ def get_extensions():
define_macros += [("WITH_HIP", None)] define_macros += [("WITH_HIP", None)]
nvcc_flags = [] nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
# FIXME: MPS build breaks custom ops registration, so it was disabled. sources += source_mps
# See https://github.com/pytorch/vision/issues/8456.
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V
if force_mps:
warnings.warn("MPS build is temporarily disabled!!!!")
# elif torch.backends.mps.is_available() or force_mps:
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
# sources += source_mps
if sys.platform == "win32": if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)] define_macros += [("torchvision_EXPORTS", None)]
......
...@@ -49,8 +49,7 @@ def pytest_collection_modifyitems(items): ...@@ -49,8 +49,7 @@ def pytest_collection_modifyitems(items):
# There are special cases though, see below # There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
# TODO: uncoment when MPS works again - see FIXME in setup.py if needs_mps and not torch.backends.mps.is_available():
if needs_mps: # and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
if IN_FBCODE: if IN_FBCODE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment