setup.py 13.4 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
7
8
9
10
11
12
#
# See LICENSE for license information.

import atexit
import os
import sys
import subprocess
import io
import re
import copy
import tempfile
13
from packaging.version import Version
Przemek Tredak's avatar
Przemek Tredak committed
14
15
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
16
from shutil import copyfile
Przemek Tredak's avatar
Przemek Tredak committed
17
18


19
# Absolute directory containing this setup script; repository-relative paths
# below are resolved against it.
path = os.path.dirname(os.path.realpath(__file__))

# The package version is the first line of the VERSION file next to this script.
# readline() keeps the trailing newline, so strip it before using the value.
with open(os.path.join(path, "VERSION"), "r") as f:
    te_version = f.readline().strip()

# Root of the CUDA toolkit installation (overridable through the environment).
CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")

# Userbuffers support is opt-in via NVTE_WITH_USERBUFFERS and needs MPI_HOME.
NVTE_WITH_USERBUFFERS = int(os.environ.get("NVTE_WITH_USERBUFFERS", "0"))
if NVTE_WITH_USERBUFFERS:
    MPI_HOME = os.environ.get("MPI_HOME", "")
    assert MPI_HOME, "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
Przemek Tredak's avatar
Przemek Tredak committed
28
29
30
31
32
33
34
35
36
37

def get_cuda_bare_metal_version(cuda_dir):
    """Return the CUDA toolkit version reported by ``nvcc -V``.

    Args:
        cuda_dir: Root of the CUDA installation; ``<cuda_dir>/bin/nvcc`` is run.

    Returns:
        A ``(major, minor)`` tuple of ints, e.g. ``(11, 8)``.

    Raises:
        ValueError: If the nvcc output contains no ``release X.Y`` token.
        OSError / subprocess.CalledProcessError: If nvcc cannot be run.
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    # Capture every digit of both components; taking only the first character
    # of the minor version would turn e.g. "12.10" into (12, 1).
    match = re.search(r"\brelease\s+(\d+)\.(\d+)", raw_output)
    if match is None:
        raise ValueError("Could not find a CUDA release version in nvcc output")
    return (int(match.group(1)), int(match.group(2)))
Przemek Tredak's avatar
Przemek Tredak committed
39
40
41


def append_nvcc_threads(nvcc_extra_args):
    """Return the nvcc arguments, adding ``--threads 4`` for CUDA 11.2 or newer.

    Args:
        nvcc_extra_args: List of nvcc command-line arguments; it is not modified.

    Returns:
        A new list with ``["--threads", "4"]`` appended when the toolkit found
        under ``CUDA_HOME`` is at least version 11.2, otherwise the input list
        itself.
    """
    # Compare (major, minor) as a tuple: checking the two components
    # independently would wrongly reject versions such as 12.0 or 12.1.
    if get_cuda_bare_metal_version(CUDA_HOME) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


def extra_gencodes(cc_flag):
    """Append ``-gencode`` flags supported by the installed CUDA toolkit.

    Args:
        cc_flag: List of compiler flags; extended in place.

    The sm_80 target is added for CUDA 11.0 or newer and the sm_90 target for
    CUDA 11.8 or newer. Nothing is returned.
    """
    detected_version = get_cuda_bare_metal_version(CUDA_HOME)
    # (minimum CUDA version, gencode argument) pairs, in the order they are added.
    gencode_by_min_version = (
        ((11, 0), "arch=compute_80,code=sm_80"),
        ((11, 8), "arch=compute_90,code=sm_90"),
    )
    for min_version, gencode in gencode_by_min_version:
        if detected_version >= min_version:
            cc_flag.extend(["-gencode", gencode])
Przemek Tredak's avatar
Przemek Tredak committed
56
57
58


def extra_compiler_flags():
    """Build the base list of nvcc flags for the PyTorch extension.

    Returns:
        A new list of flags. ``-DNVTE_WITH_USERBUFFERS`` is appended when the
        module-level ``NVTE_WITH_USERBUFFERS`` setting is enabled.
    """
    optimization_and_arch = ["-O3", "-gencode", "arch=compute_70,code=sm_70"]
    half_and_bfloat16_undefs = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
    ]
    includes_and_language_options = [
        "-I./transformer_engine/common/layer_norm/",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    flags = optimization_and_arch + half_and_bfloat16_undefs + includes_and_language_options
    if NVTE_WITH_USERBUFFERS:
        flags += ["-DNVTE_WITH_USERBUFFERS"]
    return flags
Przemek Tredak's avatar
Przemek Tredak committed
77
78
79
80
81
82
83
84
85
86
87
88
89


# Architecture-specific -gencode flags; extra_gencodes() fills the list in place
# based on the CUDA toolkit version it detects.
cc_flag = []
extra_gencodes(cc_flag)


def make_abs_path(l):
    """Join every entry of ``l`` onto the directory of this setup script.

    Args:
        l: Iterable of path strings.

    Returns:
        A new list with ``os.path.join(path, entry)`` for each entry, in order.
    """
    joined = []
    for entry in l:
        joined.append(os.path.join(path, entry))
    return joined


# Source files of the PyTorch extension, resolved against the setup directory.
pytorch_sources = make_abs_path(
    [
        "transformer_engine/pytorch/csrc/extensions.cu",
        "transformer_engine/pytorch/csrc/common.cu",
        "transformer_engine/pytorch/csrc/ts_fp8_op.cpp",
    ]
)

all_sources = pytorch_sources

# Maps each accepted framework name to the sources compiled by setuptools.
supported_frameworks = {
    "all": all_sources,
    "pytorch": pytorch_sources,
    # JAX uses transformer_engine/CMakeLists.txt
    "jax": None,
    # TensorFlow uses transformer_engine/CMakeLists.txt
    "tensorflow": None,
}

# Framework selected through the environment; "pytorch" when unset.
framework = os.getenv("NVTE_FRAMEWORK", "pytorch")
Przemek Tredak's avatar
Przemek Tredak committed
104

105
106
107
# Header search paths shared by the CMake extension and the PyTorch extension.
include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
    "3rdparty/cudnn-frontend/include",
]

# MPI headers are only needed when userbuffers support is enabled.
if NVTE_WITH_USERBUFFERS and MPI_HOME:
    include_dirs.append(os.path.join(MPI_HOME, "include"))

include_dirs = make_abs_path(include_dirs)

Przemek Tredak's avatar
Przemek Tredak committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# A "--framework=<name>" command-line option overrides the environment choice.
# Iterate over a snapshot because matching options are removed from sys.argv.
args = list(sys.argv)
for s in args:
    if not s.startswith("--framework="):
        continue
    framework = s.replace("--framework=", "")
    sys.argv.remove(s)

if framework not in supported_frameworks:
    raise ValueError("Unsupported framework " + framework)


class CMakeExtension(Extension):
    """Extension that additionally records the directory of its CMake project."""

    def __init__(self, name, cmake_path, sources, **kwargs):
        """Create the extension.

        Args:
            name: Extension name, forwarded to ``Extension``.
            cmake_path: Path stored on the instance as ``cmake_path``.
            sources: Source list, forwarded to ``Extension``.
            **kwargs: Any further ``Extension`` keyword arguments.
        """
        super().__init__(name, sources=sources, **kwargs)
        self.cmake_path = cmake_path

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class FrameworkBuilderBase:
    """No-op base class defining the hooks a framework builder may override."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""

    def cmake_flags(self):
        """Return extra CMake flags for this framework (none by default)."""
        return []

    def initialize_options(self):
        """Hook mirroring the command's ``initialize_options``; does nothing."""

    def finalize_options(self):
        """Hook mirroring the command's ``finalize_options``; does nothing."""

    def run(self, extensions):
        """Hook for building framework extensions; does nothing."""

    @staticmethod
    def install_requires():
        """Return the framework's install requirements (none by default)."""
        return []

class PyTorchBuilder(FrameworkBuilderBase):
    """Framework builder that delegates to PyTorch's ``BuildExtension`` command."""

    def __init__(self, *args, **kwargs) -> None:
        # Imported here so torch is only needed when this builder is created.
        from torch.utils.cpp_extension import BuildExtension

        # The wrapped command receives deep copies of the constructor arguments.
        self.pytorch_build_extensions = BuildExtension(
            *copy.deepcopy(args), **copy.deepcopy(kwargs)
        )

    def cmake_flags(self):
        """Return extra CMake flags; PyTorch contributes none."""
        return []

    def initialize_options(self):
        """Forward to the wrapped ``BuildExtension`` command."""
        self.pytorch_build_extensions.initialize_options()

    def finalize_options(self):
        """Forward to the wrapped ``BuildExtension`` command."""
        self.pytorch_build_extensions.finalize_options()

    def run(self, extensions):
        """Build every extension that is not a ``CMakeExtension``."""
        non_cmake_extensions = []
        for candidate in extensions:
            if not isinstance(candidate, CMakeExtension):
                non_cmake_extensions.append(candidate)
        self.pytorch_build_extensions.extensions = non_cmake_extensions
        print("Building pyTorch extensions!")
        self.pytorch_build_extensions.run()

    @staticmethod
    def install_requires():
        """Return the packages required when building for PyTorch."""
        return ["flash-attn>=1.0.2", "packaging"]
176

177

178
class TensorFlowBuilder(FrameworkBuilderBase):
    """Framework builder for TensorFlow; it only contributes CMake flags."""

    def cmake_flags(self):
        """Return the CMake flags that enable the TensorFlow build.

        ``CMAKE_PREFIX_PATH`` is set to the first ``sys.path`` entry containing
        ``dist-packages``; an ``IndexError`` is raised when there is none.
        """
        dist_packages_dirs = [entry for entry in sys.path if 'dist-packages' in entry]
        prefix_path = dist_packages_dirs[0]
        return ["-DENABLE_TENSORFLOW=ON", "-DCMAKE_PREFIX_PATH=" + prefix_path]

    def run(self, extensions):
        """Only print a progress message; no build work happens in this method."""
        print("Building TensorFlow extensions!")
Przemek Tredak's avatar
Przemek Tredak committed
185

186

187
class JaxBuilder(FrameworkBuilderBase):
    """Framework builder for JAX; it only contributes CMake flags and requirements."""

    def cmake_flags(self):
        """Return the CMake flags that enable the JAX build.

        ``CMAKE_PREFIX_PATH`` is set to the first ``sys.path`` entry containing
        ``dist-packages``; an ``IndexError`` is raised when there is none.
        """
        p = [d for d in sys.path if 'dist-packages' in d][0]
        return ["-DENABLE_JAX=ON", "-DCMAKE_PREFIX_PATH="+p]

    def run(self, extensions):
        """Only print a progress message; no build work happens in this method."""
        print("Building jax extensions!")

    # Declared static like the base class and PyTorchBuilder, so the method
    # works both on the class and on an instance.
    @staticmethod
    def install_requires():
        """Return the packages required when building for JAX."""
        # TODO: find a way to install pybind11 and ninja directly.
        return ['cmake', 'flax']
198

Przemek Tredak's avatar
Przemek Tredak committed
199
# Builder classes for the selected frameworks; filled in further below.
dlfw_builder_funcs = []

# The CMake-built extension is always the first entry of ext_modules.
ext_modules = [
    CMakeExtension(
        name="transformer_engine",
        cmake_path=os.path.join(path, "transformer_engine"),
        sources=[],
        include_dirs=include_dirs,
    )
]

if framework in ("all", "pytorch"):
    # Imported here so torch is only needed when the PyTorch build is selected.
    from torch.utils.cpp_extension import CUDAExtension

    nvcc_flags = append_nvcc_threads(extra_compiler_flags() + cc_flag)
    pytorch_extension = CUDAExtension(
        name="transformer_engine_extensions",
        sources=supported_frameworks[framework],
        extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_flags},
        include_dirs=include_dirs,
    )
    ext_modules.append(pytorch_extension)
    dlfw_builder_funcs.append(PyTorchBuilder)

if framework in ("all", "jax"):
    dlfw_builder_funcs.append(JaxBuilder)
    # Importing pybind11 here would trigger a better error when it is missing.
    # Sadly, `apt -y install pybind11-dev` does not install a Python package,
    # so the line below is too strict. When the import fails, we would need to
    # detect whether CMake can still find pybind11.
    # import pybind11

if framework in ("all", "tensorflow"):
    dlfw_builder_funcs.append(TensorFlowBuilder)

237
# Requirements common to every build, followed by those of each selected framework.
dlfw_install_requires = ['pydantic']
for builder in dlfw_builder_funcs:
    dlfw_install_requires.extend(builder.install_requires())
Przemek Tredak's avatar
Przemek Tredak committed
240
241
242
243
244
245
246


def get_cmake_bin():
    """Return a command that runs CMake 3.18 or newer.

    The ``cmake`` found on ``PATH`` is used when it is recent enough. Otherwise
    CMake 3.18 is installed with pip into a temporary directory (removed at
    interpreter exit) and the path of a small wrapper script is returned.

    Returns:
        ``"cmake"`` or the path of the generated ``run_cmake`` wrapper script.

    Raises:
        RuntimeError: If the temporary CMake installation fails.
    """
    cmake_bin = "cmake"
    try:
        out = subprocess.check_output([cmake_bin, "--version"])
    except (OSError, subprocess.CalledProcessError):
        # Missing or unusable CMake: treat it as too old.
        cmake_installed_version = Version("0.0")
    else:
        version_match = re.search(r"version\s*([\d.]+)", out.decode())
        # Unrecognized output is handled like a missing CMake rather than
        # crashing on a failed match.
        cmake_installed_version = Version(
            version_match.group(1) if version_match else "0.0"
        )

    if cmake_installed_version < Version("3.18.0"):
        print(
            "Could not find a recent CMake to build Transformer Engine. "
            "Attempting to install CMake 3.18 to a temporary location via pip.",
            flush=True,
        )
        cmake_temp_dir = tempfile.TemporaryDirectory(prefix="nvte-cmake-tmp")
        atexit.register(cmake_temp_dir.cleanup)
        try:
            # Run pip through the current interpreter so the install matches
            # the Python executing this script, not whichever `pip` is on PATH.
            subprocess.check_output(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--target",
                    cmake_temp_dir.name,
                    "cmake~=3.18.0",
                ]
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to install temporary CMake. "
                "Please update your CMake to 3.18+."
            ) from e
        # The wrapper puts the temporary directory on PYTHONPATH before
        # running the cmake script that pip installed there.
        cmake_bin = os.path.join(cmake_temp_dir.name, "bin", "run_cmake")
        with io.open(cmake_bin, "w") as f_run_cmake:
            f_run_cmake.write(
                f"#!/bin/sh\nPYTHONPATH={cmake_temp_dir.name} {os.path.join(cmake_temp_dir.name, 'bin', 'cmake')} \"$@\""
            )
        os.chmod(cmake_bin, 0o755)

    return cmake_bin


class CMakeBuildExtension(build_ext, object):
    """``build_ext`` command that configures and builds the CMake project.

    The CMake source directory is the ``cmake_path`` attribute of the first
    extension in ``self.extensions``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Extra CMake flags from the framework builders; the keyword is required.
        self.dlfw_flags = kwargs["dlfw_flags"]
        super(CMakeBuildExtension, self).__init__(*args, **kwargs)

    def build_extensions(self) -> None:
        """Run the CMake configure step followed by the CMake build step.

        Raises:
            RuntimeError: If a CMake command cannot be launched (``OSError``).
            subprocess.CalledProcessError: If a CMake command exits with a
                non-zero status.
        """
        print("Building CMake extensions!")

        cmake_bin = get_cmake_bin()
        config = "Debug" if self.debug else "Release"

        # Directory receiving the built library: the extension's full output
        # path with its file name removed.
        ext_name = self.extensions[0].name
        build_dir = self.get_ext_fullpath(ext_name).replace(
            self.get_ext_filename(ext_name), ""
        )
        build_dir = os.path.abspath(build_dir)

        cmake_args = [
            "-DCMAKE_BUILD_TYPE=" + config,
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
        ]
        # The import is only a probe: use the Ninja generator when the ninja
        # Python package is importable.
        try:
            import ninja  # noqa: F401
        except ImportError:
            pass
        else:
            cmake_args.append("-GNinja")

        cmake_args = cmake_args + self.dlfw_flags

        cmake_build_args = ["--config", config]

        # exist_ok avoids the check-then-create race of a separate exists() test.
        cmake_build_dir = os.path.join(self.build_temp, config)
        os.makedirs(cmake_build_dir, exist_ok=True)

        config_and_build_commands = [
            [cmake_bin, self.extensions[0].cmake_path] + cmake_args,
            [cmake_bin, "--build", "."] + cmake_build_args,
        ]

        print(f"Running CMake in {cmake_build_dir}:")
        for command in config_and_build_commands:
            print(" ".join(command))
        sys.stdout.flush()

        # Config and build the extension
        try:
            for command in config_and_build_commands:
                subprocess.check_call(command, cwd=cmake_build_dir)
        except OSError as e:
            raise RuntimeError("CMake failed: {}".format(str(e))) from e

class TEBuildExtension(build_ext, object):
    """Top-level ``build_ext`` command.

    It drives the CMake build first and then every selected framework builder,
    and reports the files found in ``build_lib`` as its outputs.
    """

    def __init__(self, *args, **kwargs) -> None:
        # One builder instance per selected framework, built with the same
        # arguments this command received.
        self.dlfw_builder = [
            builder_cls(*args, **kwargs) for builder_cls in dlfw_builder_funcs
        ]

        # CMake flags: the optional userbuffers switch, then each builder's flags.
        cmake_flags = ['-DNVTE_WITH_USERBUFFERS=ON'] if NVTE_WITH_USERBUFFERS else []
        for builder in self.dlfw_builder:
            cmake_flags.extend(builder.cmake_flags())

        cmake_kwargs = copy.deepcopy(kwargs)
        cmake_kwargs["dlfw_flags"] = cmake_flags
        self.cmake_build_extensions = CMakeBuildExtension(
            *copy.deepcopy(args), **cmake_kwargs
        )

        # NOTE(review): these attributes are set before the base constructor
        # runs, presumably because it may call initialize_options() - confirm
        # before reordering.
        self.all_outputs = None
        super().__init__(*args, **kwargs)

    def initialize_options(self):
        """Initialize the CMake command, every builder, then this command."""
        self.cmake_build_extensions.initialize_options()
        for builder in self.dlfw_builder:
            builder.initialize_options()
        super().initialize_options()

    def finalize_options(self):
        """Finalize the CMake command, every builder, then this command."""
        self.cmake_build_extensions.finalize_options()
        for builder in self.dlfw_builder:
            builder.finalize_options()
        super().finalize_options()

    def run(self) -> None:
        """Build everything, record the outputs and honor ``inplace``."""
        # inplace is switched off during the build and restored afterwards.
        saved_inplace = self.inplace
        self.inplace = 0

        self.cmake_build_extensions.extensions = [
            ext for ext in self.extensions if isinstance(ext, CMakeExtension)
        ]
        self.cmake_build_extensions.run()

        for builder in self.dlfw_builder:
            builder.run(self.extensions)

        # Every regular file directly inside build_lib counts as an output.
        self.all_outputs = []
        self.all_outputs.extend(
            entry.path for entry in os.scandir(self.build_lib) if entry.is_file()
        )

        self.inplace = saved_inplace
        if saved_inplace:
            self.copy_extensions_to_source()

    def copy_extensions_to_source(self):
        """Copy each file in ``build_lib`` into the first extension's package directory."""
        first_ext = self.extensions[0]
        build_py = self.get_finalized_command("build_py")
        name_parts = self.get_ext_fullname(first_ext.name).split(".")
        package_dir = build_py.get_package_dir(".".join(name_parts[:-1]))

        for entry in os.scandir(self.build_lib):
            if not entry.is_file():
                continue
            # Always copy, even if source is older than destination, to ensure
            # that the right extensions for the current Python/platform are
            # used.
            copyfile(
                entry.path,
                os.path.join(package_dir, os.path.basename(entry.path)),
            )

    def get_outputs(self):
        """Return the output files recorded by ``run`` (``None`` before it ran)."""
        return self.all_outputs


# Package metadata and build configuration; TEBuildExtension replaces the
# stock build_ext command.
setup(
    name="transformer_engine",
    version=te_version,
    description="Transformer acceleration library",
    license_files=("LICENSE",),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "examples",
            "transformer_engine.egg-info",
        )
    ),
    ext_modules=ext_modules,
    cmdclass={"build_ext": TEBuildExtension},
    install_requires=dlfw_install_requires,
    extras_require={
        'test': ['pytest', 'tensorflow_datasets'],
        'test_pytest': ['onnxruntime'],
    },
)