setup.py 2.88 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
4
5
6
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
7

Jun Ru Anderson's avatar
Jun Ru Anderson committed
8
import os
9
import re
Jun Ru Anderson's avatar
Jun Ru Anderson committed
10

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
11
12
import setuptools

13
14
# Absolute path of the directory that contains this setup script; the optional
# CUDA extension below builds its include path from it.
this_dir = os.path.split(os.path.abspath(__file__))[0]

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
15
16
17
18
19
20
21

def fetch_requirements(path=None):
    """Return the install requirements listed in a requirements file.

    Args:
        path: Optional path to the requirements file. Defaults to the
            ``requirements.txt`` located next to this ``setup.py``, so the
            result does not depend on the current working directory.

    Returns:
        A list of requirement specifiers, one per non-blank line, with
        surrounding whitespace removed. Full-line ``#`` comments are skipped.
    """
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "requirements.txt")
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # Blank lines and comment lines are not valid requirement strings; passing
        # them through would put "" or "# ..." into install_requires.
        return [line for line in stripped if line and not line.startswith("#")]


22
23
24
25
26
27
28
29
30
# https://packaging.python.org/guides/single-sourcing-package-version/
def find_version(version_file_path):
    """Return the ``__version__`` string assigned in *version_file_path*.

    Relative paths are resolved against the directory containing this
    ``setup.py`` rather than the current working directory, so the build also
    works when the script is invoked from another directory. Absolute paths
    are used as given.

    Raises:
        RuntimeError: if no line of the form ``__version__ = "..."`` is found.
    """
    setup_dir = os.path.dirname(os.path.abspath(__file__))
    # os.path.join ignores setup_dir when version_file_path is already absolute.
    with open(os.path.join(setup_dir, version_file_path), encoding="utf-8") as version_file:
        contents = version_file.read()
    # re.M makes ^ anchor at the start of every line, not only the file start.
    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if version_match is None:
        raise RuntimeError("Unable to find version string.")
    return version_match.group(1)


Jun Ru Anderson's avatar
Jun Ru Anderson committed
31
32
33
# Optional native build configuration. Both containers stay empty unless the
# environment variable BUILD_CUDA_EXTENSIONS is exactly "1".
extensions = []
cmdclass = {}

if os.getenv("BUILD_CUDA_EXTENSIONS", "0") == "1":
    # Imported only inside this branch, so torch is needed at setup time only
    # when CUDA extensions are requested.
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    # Fused Adam optimizer: one C++ binding file plus one CUDA kernel file.
    extensions.append(
        CUDAExtension(
            name="fairscale.fused_adam_cuda",
            sources=[
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp",
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "fairscale/clib/fused_adam_cuda")],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        )
    )

    # Swap in torch's build_ext command to compile the mixed C++/CUDA sources.
    cmdclass.update(build_ext=BuildExtension)


Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
54
55
56
if __name__ == "__main__":
    # Descriptive package metadata: identity, version, descriptions, authorship,
    # supported interpreters and trove classifiers.
    package_metadata = dict(
        name="fairscale",
        version=find_version("fairscale/__init__.py"),
        description="FairScale: A PyTorch library for large-scale and high-performance training.",
        long_description="FairScale is a PyTorch extension library for high performance and large scale training on one or multiple machines/nodes. This library extends basic PyTorch capabilities while adding new experimental ones.",
        long_description_content_type="text/markdown",
        author="Facebook AI Research",
        # NOTE(review): looks like a placeholder address; confirm the real contact.
        author_email="todo@fb.com",
        # NOTE(review): python_requires admits 3.6, but the classifiers start at
        # 3.7. Confirm which one is intended.
        python_requires=">=3.6",
        classifiers=[
            "Programming Language :: Python :: 3.7",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "License :: OSI Approved :: BSD License",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
            "Operating System :: OS Independent",
        ],
    )

    # Build/install configuration: dependencies, packages to ship, and the
    # optional native extensions assembled above.
    build_options = dict(
        setup_requires=["ninja"],  # ninja is required to build extensions
        install_requires=fetch_requirements(),
        include_package_data=True,
        packages=setuptools.find_packages(exclude=("tests", "tests.*")),
        ext_modules=extensions,
        cmdclass=cmdclass,
    )

    setuptools.setup(**package_metadata, **build_options)
Min Xu's avatar
Min Xu committed
79
80
81


# Bump this number if you want to force a CI cache invalidation on the pip venv.
82
# CI cache version: 3