setup.py 5.24 KB
Newer Older
1
from __future__ import print_function
soumith's avatar
soumith committed
2
import os
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
3
4
import io
import re
soumith's avatar
soumith committed
5
6
import sys
from setuptools import setup, find_packages
7
from pkg_resources import get_distribution, DistributionNotFound
8
import subprocess
9
10
11
12
13
14
import distutils.command.clean
import glob
import shutil

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
soumith's avatar
soumith committed
15
16


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
17
18
19
20
21
22
23
def read(*names, **kwargs):
    """Return the decoded text of a file addressed relative to this script.

    ``names`` are path components joined onto the directory that contains
    this file (an absolute component overrides that directory, as with
    ``os.path.join``). The optional ``encoding`` keyword selects the codec
    and defaults to ``"utf8"``.
    """
    here = os.path.dirname(__file__)
    target = os.path.join(here, *names)
    codec = kwargs.get("encoding", "utf8")
    with io.open(target, encoding=codec) as handle:
        content = handle.read()
    return content

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
24

25
26
27
28
29
30
31
def get_dist(pkgname):
    """Look up an installed distribution by name.

    Returns the ``pkg_resources`` distribution object for *pkgname*, or
    ``None`` when no distribution of that name is installed.
    """
    found = None
    try:
        found = get_distribution(pkgname)
    except DistributionNotFound:
        pass
    return found


Soumith Chintala's avatar
Soumith Chintala committed
32
# Base version; refined below from the build environment or the git checkout.
version = '0.3.0a0'
sha = 'Unknown'
package_name = os.getenv('TORCHVISION_PACKAGE_NAME', 'torchvision')

cwd = os.path.dirname(os.path.abspath(__file__))

# Best effort: record the git commit when building from a checkout. git may be
# missing (OSError), this may not be a repository (CalledProcessError), or the
# output may not decode (ValueError); in every such case sha stays 'Unknown'.
# stderr is discarded so a non-git source tree does not print
# "fatal: not a git repository" during installs.
try:
    with open(os.devnull, 'w') as devnull:
        sha = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'], cwd=cwd, stderr=devnull
        ).decode('ascii').strip()
except (OSError, subprocess.CalledProcessError, ValueError):
    pass

build_version = os.getenv('TORCHVISION_BUILD_VERSION')
if build_version:
    # Release builds: TORCHVISION_BUILD_NUMBER is mandatory. This is an explicit
    # check rather than an `assert`, which `python -O` would strip.
    build_number = os.getenv('TORCHVISION_BUILD_NUMBER')
    if build_number is None:
        raise RuntimeError(
            'TORCHVISION_BUILD_NUMBER must be set when '
            'TORCHVISION_BUILD_VERSION is set')
    build_number = int(build_number)
    version = build_version
    # Rebuilds of the same release get a ".postN" suffix.
    if build_number > 1:
        version += '.post' + str(build_number)
elif sha != 'Unknown':
    # Development builds carry the short commit hash as a local version label.
    version += '+' + sha[:7]
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` for the build being produced.

    The generated module records ``__version__`` and ``git_version`` from the
    module-level ``version`` and ``sha`` values. It also imports
    ``torchvision._C`` and, when that extension exposes ``CUDA_VERSION``,
    re-exports it as ``cuda``.
    """
    contents = [
        "__version__ = '{}'\n".format(version),
        "git_version = {}\n".format(repr(sha)),
        "from torchvision import _C\n",
        "if hasattr(_C, 'CUDA_VERSION'):\n",
        "    cuda = _C.CUDA_VERSION\n",
    ]
    target = os.path.join(cwd, 'torchvision', 'version.py')
    with open(target, 'w') as out:
        out.write(''.join(contents))


write_version_file()

# Use the shared `read` helper rather than a bare open(): the path is resolved
# next to this script (not the current working directory), the encoding is
# explicit UTF-8 instead of the locale default, and the handle is closed.
readme = read('README.rst')

pytorch_package_name = os.getenv('TORCHVISION_PYTORCH_DEPENDENCY_NAME', 'torch')

requirements = [
    'numpy',
    'six',
    pytorch_package_name,
]

# Depend on pillow-simd if it is already installed, otherwise on pillow, with
# the same minimum version either way.
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)

80

81
82
83
84
85
86
87
88
89
90
91
def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

    sources = main_file + source_cpu
    extension = CppExtension

Shahriar's avatar
Shahriar committed
92
93
94
95
96
97
98
99
100
    test_dir = os.path.join(this_dir, 'test')
    models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
    test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
    source_models = glob.glob(os.path.join(models_dir, '*.cpp'))

    test_file = [os.path.join(test_dir, s) for s in test_file]
    source_models = [os.path.join(models_dir, s) for s in source_models]
    tests = test_file + source_models

101
102
    define_macros = []

Soumith Chintala's avatar
Soumith Chintala committed
103
    extra_compile_args = {}
104
    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
105
106
107
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
Soumith Chintala's avatar
Soumith Chintala committed
108
109
110
111
112
113
114
115
116
        nvcc_flags = os.getenv('NVCC_FLAGS', '')
        if nvcc_flags == '':
            nvcc_flags = []
        else:
            nvcc_flags = nvcc_flags.split(' ')
        extra_compile_args = {
            'cxx': ['-O0'],
            'nvcc': nvcc_flags,
        }
117

118
119
120
    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]

121
122
123
    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]
Shahriar's avatar
Shahriar committed
124
    tests_include_dirs = [test_dir, models_dir]
125
126
127
128
129
130
131

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
Soumith Chintala's avatar
Soumith Chintala committed
132
            extra_compile_args=extra_compile_args,
Shahriar's avatar
Shahriar committed
133
134
135
136
137
138
139
        ),
        extension(
            'torchvision._C_tests',
            tests,
            include_dirs=tests_include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        )
    ]

    return ext_modules


class clean(distutils.command.clean.clean):
    def run(self):
        with open('.gitignore', 'r') as f:
            ignores = f.read()
            for wildcard in filter(None, ignores.split('\n')):
                for filename in glob.glob(wildcard):
                    try:
                        os.remove(filename)
                    except OSError:
                        shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


Thomas Grainger's avatar
Thomas Grainger committed
161
# Descriptive package metadata.
metadata = {
    'name': package_name,
    'version': version,
    'author': 'PyTorch Core Team',
    'author_email': 'soumith@pytorch.org',
    'url': 'https://github.com/pytorch/vision',
    'description': 'image and video datasets and models for torch deep learning',
    'long_description': readme,
    'license': 'BSD',
}

# What to install and how to build it: the Python packages (excluding `test`),
# runtime requirements, the optional scipy extra, the native extensions, and
# the custom build_ext / clean commands.
build_options = {
    'packages': find_packages(exclude=('test',)),
    'zip_safe': True,
    'install_requires': requirements,
    'extras_require': {
        "scipy": ["scipy"],
    },
    'ext_modules': get_extensions(),
    'cmdclass': {'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean},
}

setup_kwargs = dict(metadata)
setup_kwargs.update(build_options)
setup(**setup_kwargs)