from __future__ import print_function
import os
import io
import re
import sys
from setuptools import setup, find_packages
from pkg_resources import get_distribution, DistributionNotFound
import subprocess
import distutils.command.clean
import glob
import shutil

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
    """Return the text of a file addressed relative to this script's directory.

    ``names`` are path components joined onto the directory containing this
    file. The optional ``encoding`` keyword (default ``"utf8"``) selects how
    the bytes are decoded.
    """
    here = os.path.dirname(__file__)
    target = os.path.join(here, *names)
    codec = kwargs.get("encoding", "utf8")
    with io.open(target, encoding=codec) as handle:
        contents = handle.read()
    return contents

def get_dist(pkgname):
    """Look up an installed distribution by project name.

    Returns whatever ``get_distribution`` yields for *pkgname*, or ``None``
    when it raises ``DistributionNotFound``.
    """
    dist = None
    try:
        dist = get_distribution(pkgname)
    except DistributionNotFound:
        pass
    return dist


version = '0.4.0a0'
sha = 'Unknown'
package_name = os.getenv('TORCHVISION_PACKAGE_NAME', 'torchvision')
# PEP 440 local version label. Only release builds (TORCHVISION_BUILD_VERSION)
# read it from the environment. It is always bound here because the torch
# dependency pin reads it whenever TORCHVISION_PYTORCH_DEPENDENCY_VERSION is
# set, even for non-release builds.
local_label = None

cwd = os.path.dirname(os.path.abspath(__file__))

try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except Exception:
    # Best effort: outside a git checkout (or without git) keep 'Unknown'.
    pass

if os.getenv('TORCHVISION_BUILD_VERSION'):
    # Release build: the version is TORCHVISION_BUILD_VERSION, plus ".postN"
    # when TORCHVISION_BUILD_NUMBER is greater than 1, plus an optional
    # "+<TORCHVISION_LOCAL_VERSION_LABEL>".
    build_number_str = os.getenv('TORCHVISION_BUILD_NUMBER')
    if build_number_str is None:
        # Explicit check rather than `assert`, which is stripped under -O.
        raise RuntimeError(
            'TORCHVISION_BUILD_NUMBER must be set when '
            'TORCHVISION_BUILD_VERSION is set')
    build_number = int(build_number_str)

    base_version = os.getenv('TORCHVISION_BUILD_VERSION')
    version = base_version

    if build_number > 1:
        version += '.post' + str(build_number)

    local_label = os.getenv('TORCHVISION_LOCAL_VERSION_LABEL')
    if local_label is not None:
        version += '+' + local_label
elif sha != 'Unknown':
    # Development build: tag the version with the short git commit hash.
    version += '+' + sha[:7]
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` under the module-level ``cwd``.

    The generated module defines ``__version__`` (from the module-level
    ``version``) and ``git_version`` (``repr`` of the module-level ``sha``).
    It also imports ``torchvision._C``. When that compiled module has a
    ``CUDA_VERSION`` attribute, the generated module exposes it as ``cuda``.
    """
    version_path = os.path.join(cwd, 'torchvision', 'version.py')
    with open(version_path, 'w') as f:
        f.write("__version__ = '{}'\n".format(version))
        f.write("git_version = {}\n".format(repr(sha)))
        f.write("from torchvision import _C\n")
        f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
        f.write("    cuda = _C.CUDA_VERSION\n")


write_version_file()

# Use the file's read() helper: it resolves the path next to this script,
# decodes as UTF-8, and closes the handle (the bare open() leaked it).
readme = read('README.rst')

# PyTorch requirement, optionally pinned to an exact version.
pytorch_dep = os.getenv('TORCHVISION_PYTORCH_DEPENDENCY_NAME', 'torch')
if os.getenv('TORCHVISION_PYTORCH_DEPENDENCY_VERSION'):
    pytorch_dep += "==" + os.getenv('TORCHVISION_PYTORCH_DEPENDENCY_VERSION')
    # torchvision has CUDA bits, thus, we must specify a local dependency
    if local_label is not None:
        pytorch_dep += '+' + local_label

requirements = [
    'numpy',
    'six',
    pytorch_dep,
]

# Depend on pillow-simd instead of pillow when pillow-simd is already installed.
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)


def get_extensions():
    """Describe the compiled extension modules to build.

    Returns a list of two extensions built from sources next to this script:

    * ``torchvision._C``: ``torchvision/csrc/*.cpp`` and
      ``torchvision/csrc/cpu/*.cpp``, plus ``torchvision/csrc/cuda/*.cu``
      when building with CUDA.
    * ``torchvision._C_tests``: ``test/*.cpp`` and
      ``torchvision/csrc/models/*.cpp``.

    The extensions are built as ``CUDAExtension``s, with ``WITH_CUDA``
    defined, when ``torch.cuda.is_available()`` is true and ``CUDA_HOME`` is
    known, or when ``FORCE_CUDA`` is ``'1'``. Otherwise they are built as
    ``CppExtension``s. Extra nvcc flags come from the ``NVCC_FLAGS``
    environment variable.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
    models_dir = os.path.join(extensions_dir, 'models')
    test_dir = os.path.join(this_dir, 'test')

    # The patterns are absolute, so glob() already returns absolute paths.
    # No further os.path.join is needed.
    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
    test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
    source_models = glob.glob(os.path.join(models_dir, '*.cpp'))

    sources = main_file + source_cpu
    tests = test_file + source_models

    extension = CppExtension
    define_macros = []
    extra_compile_args = {}

    use_cuda = (
        (torch.cuda.is_available() and CUDA_HOME is not None)
        or os.getenv('FORCE_CUDA', '0') == '1'
    )
    if use_cuda:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        # str.split() with no separator drops empty fields. An unset
        # NVCC_FLAGS yields [], and repeated spaces no longer become
        # empty-string arguments on the nvcc command line.
        nvcc_flags = os.getenv('NVCC_FLAGS', '').split()
        extra_compile_args = {
            # NOTE(review): -O0 disables host-compiler optimization for CUDA
            # builds; confirm this is intended and not a leftover debug flag.
            'cxx': ['-O0'],
            'nvcc': nvcc_flags,
        }

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]

    include_dirs = [extensions_dir]
    tests_include_dirs = [test_dir, models_dir]

    ext_modules = []
    for name, ext_sources, ext_include_dirs in (
        ('torchvision._C', sources, include_dirs),
        ('torchvision._C_tests', tests, tests_include_dirs),
    ):
        ext_modules.append(
            extension(
                name,
                ext_sources,
                include_dirs=ext_include_dirs,
                # Give each extension its own copies. BuildExtension may append
                # per-extension flags (e.g. the TORCH_EXTENSION_NAME define) to
                # these containers in place, and a shared object would leak one
                # extension's flags into the other.
                define_macros=list(define_macros),
                extra_compile_args={k: list(v) for k, v in extra_compile_args.items()},
            )
        )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes everything matched by ``.gitignore``.

    Each non-empty line of ``.gitignore`` is used as a glob pattern. Every
    match is removed as a file; when ``os.remove`` raises ``OSError`` it is
    removed with ``shutil.rmtree`` (errors ignored). The base distutils clean
    then runs.
    """

    def run(self):
        with open('.gitignore', 'r') as f:
            patterns = [line for line in f.read().split('\n') if line]
            for pattern in patterns:
                for path in glob.glob(pattern):
                    try:
                        os.remove(path)
                    except OSError:
                        shutil.rmtree(path, ignore_errors=True)

        # distutils' clean is an old-style class on Python 2.7, so super()
        # cannot be used; call the base implementation explicitly.
        distutils.command.clean.clean.run(self)


setup(
    # Metadata
    name=package_name,
    version=version,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=readme,
    license='BSD',

    # Package info
    packages=find_packages(exclude=('test',)),

    # NOTE(review): this package ships compiled extension modules
    # (torchvision._C, torchvision._C_tests); confirm zip_safe=True is intended.
    zip_safe=True,
    install_requires=requirements,
    extras_require={
        "scipy": ["scipy"],
    },

    # Compiled C++/CUDA extensions, built via torch's BuildExtension command.
    ext_modules=get_extensions(),
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean}
)