setup.py 3.04 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
4
5
6
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
7

Jun Ru Anderson's avatar
Jun Ru Anderson committed
8
import os
9
import re
Jun Ru Anderson's avatar
Jun Ru Anderson committed
10
11
import warnings

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
12
import setuptools
Jun Ru Anderson's avatar
Jun Ru Anderson committed
13
14
import torch
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
15

16
17
# Absolute path of the directory holding this setup.py, so paths built from it
# do not depend on the current working directory.
this_dir = os.path.split(os.path.abspath(__file__))[0]

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
18
19
20
21
22
23
24

def fetch_requirements(path=None):
    """Return the runtime dependencies listed in a pip requirements file.

    Args:
        path: Requirements file to read. Defaults to ``requirements.txt`` in the
            directory containing this ``setup.py``, so the result does not depend
            on the current working directory.

    Returns:
        A list with one requirement string per non-blank, non-comment line,
        stripped of surrounding whitespace, in file order.
    """
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "requirements.txt")
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # Blank lines and full-line ``#`` comments are legal in requirements files
        # but must not be forwarded to ``install_requires`` as empty/bogus entries.
        return [line for line in stripped if line and not line.startswith("#")]


25
26
27
28
29
30
31
32
33
# https://packaging.python.org/guides/single-sourcing-package-version/
def find_version(version_file_path):
    """Extract the ``__version__`` string from a Python source file.

    Args:
        version_file_path: Path of the file that assigns ``__version__``.
            Relative paths resolve against the current working directory.

    Returns:
        The quoted value of the first line-anchored ``__version__ = "..."``
        (or single-quoted) assignment.

    Raises:
        RuntimeError: If no such assignment is found.
    """
    # Python source files are UTF-8 by default (PEP 3120); don't rely on the
    # locale's default encoding, which can fail on non-ASCII content.
    with open(version_file_path, encoding="utf-8") as version_file:
        contents = version_file.read()
    # Accept any spaces/tabs around "=" (e.g. `__version__="1.0"` or aligned assignments).
    version_match = re.search(r"^__version__[ \t]*=[ \t]*['\"]([^'\"]*)['\"]", contents, re.M)
    if version_match:
        return version_match.group(1)
    raise RuntimeError("Unable to find version string.")


Jun Ru Anderson's avatar
Jun Ru Anderson committed
34
35
36
37
38
39
40
41
42
# Optional fused-Adam CUDA extension. It is compiled when torch reports an
# available CUDA device and has located CUDA_HOME, or whenever the FORCE_CUDA
# environment variable is "1". Otherwise a warning is emitted and
# `extensions` / `cmdclass` stay empty.
extensions = []
cmdclass = {}

force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
    extensions.append(
        CUDAExtension(
            name="fairscale.fused_adam_cuda",
            include_dirs=[os.path.join(this_dir, "fairscale/clib/fused_adam_cuda")],
            sources=[
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp",
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu",
            ],
            # -O3 for host (cxx) and device (nvcc) code; fast math only for nvcc.
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        )
    )
    # torch.utils.cpp_extension's build_ext is needed to compile CUDAExtension modules.
    cmdclass["build_ext"] = BuildExtension
else:
    warnings.warn("Cannot install FusedAdam cuda.")


Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
58
59
60
if __name__ == "__main__":
    # Trove classifiers published with the package metadata.
    # NOTE(review): python_requires below allows 3.6, but no 3.6 classifier is
    # listed -- confirm which one reflects the supported versions.
    classifiers = [
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "License :: OSI Approved :: BSD License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Operating System :: OS Independent",
    ]
    long_description = (
        "FairScale is a PyTorch extension library for high performance and large scale "
        "training on one or multiple machines/nodes. This library extends basic PyTorch "
        "capabilities while adding new experimental ones."
    )

    setuptools.setup(
        name="fairscale",
        description="FairScale: A PyTorch library for large-scale and high-performance training.",
        long_description=long_description,
        long_description_content_type="text/markdown",
        # Version is single-sourced from the package's __init__.py.
        version=find_version("fairscale/__init__.py"),
        author="Facebook AI Research",
        author_email="todo@fb.com",
        python_requires=">=3.6",
        setup_requires=["ninja"],  # ninja is required to build extensions
        install_requires=fetch_requirements(),
        include_package_data=True,
        packages=setuptools.find_packages(exclude=("tests", "tests.*")),
        # Empty unless the CUDA extension was configured above.
        ext_modules=extensions,
        cmdclass=cmdclass,
        classifiers=classifiers,
    )
Min Xu's avatar
Min Xu committed
83
84
85
86


# Bump this number if you want to force a CI cache invalidation on the pip venv.
# CI cache version: 1