Unverified Commit ca571485 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Disable C++ models from being compiled without explicit request (#1535)

* Disable C++ models from being compiled without explicitly being asked for

* Fix import in tests, which are already disabled
parent 6a991f8b
...@@ -88,6 +88,8 @@ def get_extensions(): ...@@ -88,6 +88,8 @@ def get_extensions():
sources = main_file + source_cpu sources = main_file + source_cpu
extension = CppExtension extension = CppExtension
compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
if compile_cpp_tests:
test_dir = os.path.join(this_dir, 'test') test_dir = os.path.join(this_dir, 'test')
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
test_file = glob.glob(os.path.join(test_dir, '*.cpp')) test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
...@@ -96,6 +98,7 @@ def get_extensions(): ...@@ -96,6 +98,7 @@ def get_extensions():
test_file = [os.path.join(test_dir, s) for s in test_file] test_file = [os.path.join(test_dir, s) for s in test_file]
source_models = [os.path.join(models_dir, s) for s in source_models] source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models tests = test_file + source_models
tests_include_dirs = [test_dir, models_dir]
define_macros = [] define_macros = []
...@@ -123,7 +126,6 @@ def get_extensions(): ...@@ -123,7 +126,6 @@ def get_extensions():
sources = [os.path.join(extensions_dir, s) for s in sources] sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir] include_dirs = [extensions_dir]
tests_include_dirs = [test_dir, models_dir]
ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
has_ffmpeg = ffmpeg_exe is not None has_ffmpeg = ffmpeg_exe is not None
...@@ -143,15 +145,18 @@ def get_extensions(): ...@@ -143,15 +145,18 @@ def get_extensions():
include_dirs=include_dirs, include_dirs=include_dirs,
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
), )
]
if compile_cpp_tests:
ext_modules.append(
extension( extension(
'torchvision._C_tests', 'torchvision._C_tests',
tests, tests,
include_dirs=tests_include_dirs, include_dirs=tests_include_dirs,
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
), )
] )
if has_ffmpeg: if has_ffmpeg:
ext_modules.append( ext_modules.append(
CppExtension( CppExtension(
......
import torch import torch
import os import os
import unittest import unittest
from torchvision import models, transforms, _C_tests from torchvision import models, transforms
import sys import sys
from PIL import Image from PIL import Image
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
try:
from torchvision import _C_tests
except ImportError:
_C_tests = None
def process_model(model, tensor, func, name): def process_model(model, tensor, func, name):
model.eval() model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment