# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""The setuptools-based setup module.

Reference:
    https://packaging.python.org/guides/distributing-packages-using-setuptools/
"""

import os
import sys
import pathlib
13
from typing import List, Tuple, ClassVar
14
15
16
17
18

from setuptools import setup, find_packages, Command

import superbench

# Report the running interpreter's major.minor version so build logs show which Python ran setup.py.
print('Python {}.{} detected.'.format(*sys.version_info[:2]))

# Interpreters older than 3.11 verify the pip/setuptools pins through pkg_resources.
# On a missing or conflicting distribution, print an upgrade hint (ANSI bright yellow)
# and re-raise the original exception so the build aborts.
if sys.version_info[:2] < (3, 11):
    import pkg_resources
    try:
        pkg_resources.require(['pip>=18', 'setuptools>=45, <66'])
    except (pkg_resources.DistributionNotFound, pkg_resources.VersionConflict):
        print(
            '\033[93m'
            'Try update pip/setuptools versions, for example, '
            'python3 -m pip install --upgrade pip wheel setuptools==65.7'
            '\033[0m'
        )
        raise

# README.md next to this file supplies the long_description passed to setup() below.
here = pathlib.Path(__file__).parent.resolve()
long_description = here.joinpath('README.md').read_text(encoding='utf-8')


class Formatter(Command):
    """Cmdclass for `python setup.py format`.

    Rewrites the repository's Python sources in place with yapf (excluding
    `.git` and `.eggs`), running yapf under the same interpreter that is
    executing setup.py.

    Args:
        Command (distutils.cmd.Command):
            Abstract base class for defining command classes.
    """

    description = 'format the code using yapf'
    user_options: ClassVar[List[Tuple[str, str, str]]] = []

    def initialize_options(self):
        """Set default values for options that this command supports."""
        pass

    def finalize_options(self):
        """Set final values for options that this command supports."""
        pass

    def run(self):
        """Format the code using yapf.

        On Python 3.12+ formatting is skipped and the method just returns.
        Otherwise the process exits with status 0 if yapf succeeds and 1 if
        it fails.
        """
        import subprocess

        if sys.version_info[:2] >= (3, 12):
            # TODO: Remove this block when yapf is compatible with Python 3.12+.
            print('Disable yapf for Python 3.12+ due to the compatibility issue.')
            return
        # Use sys.executable rather than whatever `python3` is on PATH, so yapf runs
        # under the interpreter whose version was checked above. An argument list
        # (no shell) also keeps paths with spaces in sys.executable working.
        result = subprocess.run(
            [
                sys.executable, '-m', 'yapf', '--in-place', '--recursive', '--exclude', '.git', '--exclude', '.eggs',
                '.'
            ],
            check=False,
        )
        sys.exit(0 if result.returncode == 0 else 1)


class Linter(Command):
    """Cmdclass for `python setup.py lint`.

    Runs yapf (diff mode), mypy and flake8 in that order under the same
    interpreter that is executing setup.py, stopping at the first failure.

    Args:
        Command (distutils.cmd.Command):
            Abstract base class for defining command classes.
    """

    description = 'lint the code using flake8'
    user_options: ClassVar[List[Tuple[str, str, str]]] = []

    def initialize_options(self):
        """Set default values for options that this command supports."""
        pass

    def finalize_options(self):
        """Set final values for options that this command supports."""
        pass

    def run(self):
        """Lint the code with yapf, mypy, and flake8.

        The checks run sequentially and stop at the first failing one. The
        process exits with status 0 if every check passes and 1 otherwise.
        On Python 3.12+ only the yapf check is skipped; mypy and flake8
        still run.
        """
        import subprocess

        # Each entry is the argument list after `python -m`.
        checks = []
        if sys.version_info[:2] >= (3, 12):
            # TODO: Remove this block when yapf is compatible with Python 3.12+.
            print('Disable yapf check for Python 3.12+ due to the compatibility issue.')
        else:
            checks.append(['yapf', '--diff', '--recursive', '--exclude', '.git', '--exclude', '.eggs', '.'])
        checks.append(['mypy', '.'])
        checks.append(['flake8'])

        for args in checks:
            # Run under the current interpreter with no shell. This short-circuits like
            # the former `&&` chain but does not rely on the POSIX-only `:` no-op.
            if subprocess.run([sys.executable, '-m', *args], check=False).returncode != 0:
                sys.exit(1)
        sys.exit(0)


class Tester(Command):
    """Cmdclass for `python setup.py test`.

    Runs pytest over `tests/` with coverage for the `superbench` package,
    under the same interpreter that is executing setup.py.

    Args:
        Command (distutils.cmd.Command):
            Abstract base class for defining command classes.
    """

    description = 'test the code using pytest'
    user_options: ClassVar[List[Tuple[str, str, str]]] = []

    def initialize_options(self):
        """Set default values for options that this command supports."""
        pass

    def finalize_options(self):
        """Set final values for options that this command supports."""
        pass

    def run(self):
        """Run pytest.

        Coverage is reported both as XML and as a terminal summary listing
        missing lines. The process exits with status 0 if pytest succeeds
        and 1 otherwise.
        """
        import subprocess

        # sys.executable (not `python3` from PATH) guarantees the tests run against
        # the interpreter and environment that invoked setup.py.
        result = subprocess.run(
            [
                sys.executable, '-m', 'pytest', '-v', '--cov=superbench', '--cov-report=xml',
                '--cov-report=term-missing', 'tests/'
            ],
            check=False,
        )
        sys.exit(0 if result.returncode == 0 else 1)


setup(
    name='superbench',
    version=superbench.__version__,
    description='Provide hardware and software benchmarks for AI systems.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/microsoft/superbenchmark',
    author=superbench.__author__,
    author_email='superbench@microsoft.com',
    license='MIT',
    classifiers=[
        'Development Status :: 2 - Pre-Alpha',
        'Environment :: GPU',
        'Intended Audience :: System Administrators',
        'License :: OSI Approved :: MIT License',
        'Operating System :: POSIX',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3 :: Only',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'Programming Language :: Python :: 3.13',
        'Topic :: System :: Benchmark',
        'Topic :: System :: Clustering',
        'Topic :: System :: Hardware',
    ],
    keywords='benchmark, AI systems',
    packages=find_packages(exclude=['tests']),
    python_requires='>=3.7, <4',
    # setuptools_scm configuration. The version_scheme callable ignores its argument and
    # always returns superbench.__version__, so the package version comes from the code.
    # The fallback is used when no SCM metadata is available.
    use_scm_version={
        'local_scheme': 'node-and-date',
        'version_scheme': lambda _: superbench.__version__,
        'fallback_version': f'{superbench.__version__}+unknown',
    },
    setup_requires=[
        'setuptools_scm',
    ],
    install_requires=[
        # ansible packaging differs by interpreter version (split at Python 3.10).
        'ansible;os_name=="posix" and python_version>"3.10"',
        'ansible_base>=2.10.9;os_name=="posix" and python_version<="3.10"',
        'ansible_runner>=2.0.0rc1, <2.3.2;python_version<="3.10"',
        'ansible_runner;python_version>"3.10"',
        'colorlog>=6.7.0',
        'importlib_metadata',
        'jinja2>=2.10.1',
        'joblib>=1.0.1',
        'jsonlines>=2.0.0',
        'knack>=0.7.2',
        'markdown>=3.3.0',
        'matplotlib>=3.0.0',
        'natsort>=7.1.1',
        'networkx>=2.5',
        'numpy>=1.19.2',
        'omegaconf==2.3.0',
        'openpyxl>=3.0.7',
        'packaging>=21.0',
        'pandas>=1.1.5',
        'protobuf',
        'pssh @ git+https://github.com/lilydjwg/pssh.git@v2.3.4',
        'pyyaml>=5.3',
        'requests>=2.27.1',
        'seaborn>=0.11.2',
        'tcping>=0.1.1rc1',
        'urllib3>=1.26.9',
        'xlrd>=2.0.1',
        'xlsxwriter>=1.3.8',
        'xmltodict>=0.12.0',
        'types-requests',
    ],
    # The immediately-invoked lambda keeps every base extra and derives composite extras by
    # concatenating them: 'develop' = dev + test, and each *worker extra bundles torch/ort
    # with the vendor-specific group.
    extras_require=(
        lambda x: {
            **x,
            'develop': x['dev'] + x['test'],
            'cpuworker': x['torch'],
            'amdworker': x['torch'] + x['amd'],
            'nvworker': x['torch'] + x['ort'] + x['nvidia'],
            'hgworker': x['amd'],
        }
    )(
        {
            'dev': ['pre-commit>=2.10.0'],
            'test': [
                'flake8-docstrings>=1.5.0',
                'flake8-quotes>=3.2.0',
                'flake8>=3.8.4',
                'mypy>=0.800',
                'pydocstyle>=5.1.1',
                'pytest-cov>=2.11.1',
                'pytest-subtests>=0.4.0',
                'pytest>=6.2.2, <=7.4.4',
                'types-markdown',
                'types-setuptools',
                'types-pyyaml',
                'typing-extensions>=3.10',
                'urllib3<2.0',
                'vcrpy>=4.1.1',
                'yapf==0.31.0',
            ],
            'torch': [
                'safetensors==0.4.5; python_version<"3.12"',
                'safetensors>=0.5.3; python_version>="3.12"',
                'tokenizers<=0.20.3; python_version<"3.12"',
                'tokenizers<0.22; python_version>="3.12"',
                'torch>=1.7.0a0',
                'torchvision>=0.8.0a0',
                'transformers>=4.28.0; python_version<"3.12"',
                'transformers==4.52.4; python_version>="3.12"',
            ],
            'ort': [
                'onnx>=1.10.2',
                'onnxruntime-gpu==1.12.0; python_version<"3.10" and platform_machine == "x86_64"',
                'onnxruntime-gpu; python_version>="3.10" and platform_machine == "x86_64"',
            ],
            'nvidia': ['py3nvml>=0.2.6'],
            'amd': ['amdsmi'],
        }
    ),
    include_package_data=True,
    entry_points={
        'console_scripts': [
            'sb = superbench.cli.sb:main',
        ],
    },
    # Custom `python setup.py format|lint|test` commands defined above.
    cmdclass={
        'format': Formatter,
        'lint': Linter,
        'test': Tester,
    },
    project_urls={
        'Source': 'https://github.com/microsoft/superbenchmark',
        'Tracker': 'https://github.com/microsoft/superbenchmark/issues',
    },
)