setup.py 3.78 KB
Newer Older
pkufool's avatar
pkufool committed
1
2
3
4
5
#!/usr/bin/env python3
#
# Copyright (c)  2022  Xiaomi Corporation (author: Wei Kang)

import glob
anton's avatar
anton committed
6
import os
pkufool's avatar
pkufool committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import re
import shutil
import sys

import setuptools
from setuptools.command.build_ext import build_ext

# Absolute path of the directory containing this setup script.
cur_dir = os.path.split(os.path.abspath(__file__))[0]


def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
    """Create a source-less C++ ``setuptools.Extension`` called *name*.

    No sources are listed because the extension is built by the custom
    ``BuildExtension`` command (which drives CMake/make), not by setuptools.
    Any ``language`` keyword supplied by the caller is overridden with
    ``"c++"``; all other positional/keyword arguments are forwarded as-is.
    """
    extension_options = dict(kwargs, language="c++")
    return setuptools.Extension(name, [], *args, **extension_options)
21

anton's avatar
anton committed
22

pkufool's avatar
pkufool committed
23
24
25
26
27
class BuildExtension(build_ext):
    """``build_ext`` command that builds ``_fast_rnnt`` via CMake and make.

    Instead of compiling ``Extension.sources`` itself, it configures the
    project with ``cmake`` in ``self.build_temp``, builds the ``_fast_rnnt``
    make target, and copies the resulting shared libraries into
    ``self.build_lib``.

    Environment variables read:

    - ``FT_CMAKE_ARGS``: arguments for ``cmake``. When empty, a Release build
      with ``-DFT_BUILD_TESTS=OFF`` is configured.
    - ``FT_MAKE_ARGS``: arguments for ``make``. When both this and
      ``MAKEFLAGS`` are empty, ``-j`` is passed.
    """

    def build_extension(self, ext: setuptools.extension.Extension):
        """Run cmake + make for ``_fast_rnnt`` and collect the built libraries.

        Args:
          ext: The extension being built. Its fields are not read; the make
            target is always ``_fast_rnnt``.

        Raises:
          RuntimeError: if the cmake/make shell command exits with a non-zero
            status.
        """
        import shlex  # only needed here, for quoting paths in the command

        # build/temp.linux-x86_64-3.8
        build_dir = self.build_temp
        os.makedirs(build_dir, exist_ok=True)

        # build/lib.linux-x86_64-3.8
        os.makedirs(self.build_lib, exist_ok=True)

        ft_dir = os.path.dirname(os.path.abspath(__file__))

        cmake_args = os.environ.get("FT_CMAKE_ARGS", "")
        make_args = os.environ.get("FT_MAKE_ARGS", "")
        system_make_args = os.environ.get("MAKEFLAGS", "")

        if cmake_args == "":
            cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"

        if make_args == "" and system_make_args == "":
            # NOTE: a bare -j lets make run an unlimited number of jobs.
            make_args = " -j "

        if "PYTHON_EXECUTABLE" not in cmake_args:
            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
            # Quote it: the interpreter path may contain spaces.
            cmake_args += f" -DPYTHON_EXECUTABLE={shlex.quote(sys.executable)}"

        # Chain the steps with && so a failing `cd` or `cmake` aborts the
        # build instead of the next step running anyway (a failed `cd` would
        # otherwise configure inside the current directory). Paths we compute
        # are quoted; user-supplied cmake/make args are deliberately left
        # unquoted so they keep their shell semantics.
        build_cmd = (
            f"cd {shlex.quote(build_dir)} && "
            f"cmake {cmake_args} {shlex.quote(ft_dir)} && "
            f"make {make_args} _fast_rnnt"
        )
        print(f"build command is:\n{build_cmd}")

        ret = os.system(build_cmd)
        if ret != 0:
            raise RuntimeError(
                "\nBuild fast_rnnt failed. Please check the error "
                "message.\n"
                "You can ask for help by creating an issue on GitHub.\n"
                "\nClick:\n"
                "\thttps://github.com/danpovey/fast_rnnt/issues/new\n"  # noqa
            )

        # *.so* covers Linux builds, *.dylib* covers macOS builds.
        lib_dir = os.path.join(glob.escape(build_dir), "lib")
        for pattern in ("*.so*", "*.dylib*"):
            for so in glob.glob(os.path.join(lib_dir, pattern)):
                print(f"Copying {so} to {self.build_lib}/")
                shutil.copy(so, f"{self.build_lib}/")


def read_long_description():
    """Return the full text of ``README.md`` from the current directory.

    The text is passed to ``setuptools.setup`` as the package's long
    description.
    """
    readme_path = "README.md"
    with open(readme_path, encoding="utf8") as readme_file:
        return readme_file.read()


def get_package_version():
    """Return the version declared in ``CMakeLists.txt``.

    Finds the first ``set(FT_VERSION <value>)`` command and returns
    ``<value>`` with surrounding whitespace and double quotes removed, e.g.
    ``set(FT_VERSION "1.2")`` -> ``1.2``.

    Raises:
      RuntimeError: if ``CMakeLists.txt`` has no ``set(FT_VERSION ...)``.
    """
    with open("CMakeLists.txt", encoding="utf8") as f:
        content = f.read()

    # `[^)]*` stops at the first ")" so a trailing comment on the same line
    # that contains a parenthesis is not captured as part of the version.
    match = re.search(r"set\(FT_VERSION\s+([^)]*)\)", content)
    if match is None:
        raise RuntimeError(
            "Could not find 'set(FT_VERSION ...)' in CMakeLists.txt"
        )
    return match.group(1).strip().strip('"')

Wei Kang's avatar
Wei Kang committed
92

pkufool's avatar
pkufool committed
93
94
95
96
97
98
def get_requirements():
    """Return the requirement specifiers listed in ``requirements.txt``.

    Each line is stripped of surrounding whitespace. Blank lines and lines
    starting with ``#`` are skipped, so only real specifiers are handed to
    ``install_requires``.
    """
    with open("requirements.txt", encoding="utf8") as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith("#")]

Wei Kang's avatar
Wei Kang committed
99

pkufool's avatar
pkufool committed
100
101
package_name = "fast_rnnt"

# Read the version once so the stamped ``__version__`` and the metadata given
# to setuptools.setup() always agree.
package_version = get_package_version()

# Stamp the version into the package's __init__.py. Any ``__version__ = ``
# line from an earlier run is dropped before the current one is appended.
# Otherwise every invocation of setup.py would add another line to the file.
_init_py = "fast_rnnt/python/fast_rnnt/__init__.py"
try:
    with open(_init_py, encoding="utf8") as f:
        _old_init = f.read()
except FileNotFoundError:
    _old_init = ""
_kept_lines = [
    line
    for line in _old_init.splitlines(keepends=True)
    if not line.startswith("__version__ = ")
]
if _kept_lines and not _kept_lines[-1].endswith("\n"):
    # Make sure the version assignment starts on its own line.
    _kept_lines[-1] += "\n"
_kept_lines.append(f"__version__ = '{package_version}'\n")
_new_init = "".join(_kept_lines)
if _new_init != _old_init:
    # Only rewrite when something changed, to avoid touching the file.
    with open(_init_py, "w", encoding="utf8") as f:
        f.write(_new_init)

setuptools.setup(
    name=package_name,
    version=package_version,
    author="Dan Povey",
    author_email="dpovey@gmail.com",
    package_dir={
        package_name: "fast_rnnt/python/fast_rnnt",
    },
    packages=[package_name],
    url="https://github.com/danpovey/fast_rnnt",
    description="Fast and memory-efficient RNN-T loss.",
    long_description=read_long_description(),
    long_description_content_type="text/markdown",
    install_requires=get_requirements(),
    ext_modules=[cmake_extension("_fast_rnnt")],
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
    classifiers=[
        "Programming Language :: C++",
        "Programming Language :: Python",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    license="Apache licensed, as found in the LICENSE file",
)