setup.py 36.6 KB
Newer Older
1
import torch
Masaki Kozuki's avatar
Masaki Kozuki committed
2
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
3
from setuptools import setup, find_packages
mcarilli's avatar
mcarilli committed
4
import subprocess
5

jjsjann123's avatar
jjsjann123 committed
6
import sys
Marek Kolodziej's avatar
Marek Kolodziej committed
7
import warnings
mcarilli's avatar
mcarilli committed
8
import os
jjsjann123's avatar
jjsjann123 committed
9

10
11
12
# ninja refuses to build unless include_dirs are absolute, so resolve the
# repository root once, up front.
this_dir = os.path.abspath(os.path.dirname(__file__))

ptrblck's avatar
ptrblck committed
13
14
15
16
17
18
19
20
21
22
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``<cuda_dir>/bin/nvcc -V`` and parse the installed toolkit version.

    Returns a 3-tuple ``(raw nvcc output, major, minor)`` where major/minor are
    strings, e.g. ``("...", "11", "2")``.
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    tokens = raw_output.split()
    # The version string immediately follows the literal token "release",
    # e.g. "... release 11.2, V11.2.152".
    version = tokens[tokens.index("release") + 1]
    major, minor = version.split(".")[:2]
    # Only the first character of the minor component is kept (the token may
    # carry a trailing comma), mirroring the historical parsing behavior.
    return raw_output, major, minor[0]

Jithun Nair's avatar
Jithun Nair committed
23
24
25
26
print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

27
28
def check_if_rocm_pytorch():
    """Return True when this PyTorch is a ROCm (hip) build with ROCM_HOME set."""
    # ROCM_HOME only exists in torch >= 1.5, so guard the import by version.
    if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
        from torch.utils.cpp_extension import ROCM_HOME
        return (torch.version.hip is not None) and (ROCM_HOME is not None)
    return False

IS_ROCM_PYTORCH = check_if_rocm_pytorch()

if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query
    # torch.cuda.get_device_capability(), which fails when no GPU is visible (e.g. during an
    # `nvidia-docker build`), so fall back to an explicit cross-compile arch list instead.
    print('\nWarning: Torch did not find available GPUs on this system.\n',
          'If your intention is to cross-compile, this is not an error.\n'
          'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
          'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
          'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
          'If you wish to cross-compile for a single specific architecture,\n'
          'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
    # Respect a user-provided TORCH_CUDA_ARCH_LIST; otherwise derive one from nvcc.
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        arch_list = "6.0;6.1;6.2;7.0;7.5"
        # Compute capability 8.0 is added only for the CUDA 11 toolkit.
        # NOTE(review): `== 11` will not add 8.0 for CUDA majors > 11 — confirm intent.
        if int(bare_metal_major) == 11:
            arch_list = arch_list + ";8.0"
        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
elif not torch.cuda.is_available() and IS_ROCM_PYTORCH:
    print('\nWarning: Torch did not find available GPUs on this system.\n',
          'If your intention is to cross-compile, this is not an error.\n'
          'By default, Apex will cross-compile for the same gfx targets\n'
          'used by default in ROCm PyTorch\n')
59

60
# Apex does not support the pre-0.4 releases of PyTorch.
if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

jjsjann123's avatar
jjsjann123 committed
64
65
66
# Accumulators filled in by the per-option handlers below and ultimately
# handed to setup(): custom commands, compiled extensions, extras_require.
cmdclass = {}
ext_modules = []
extras = {}
Marek Kolodziej's avatar
Marek Kolodziej committed
68
if "--pyprof" in sys.argv:
69
70
71
72
73
    string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
             "soon be removed from Apex.  Please visit\n" + \
             "https://github.com/NVIDIA/PyProf\n" + \
             "for the latest version."
    warnings.warn(string, DeprecationWarning)
Marek Kolodziej's avatar
Marek Kolodziej committed
74
75
    with open('requirements.txt') as f:
        required_packages = f.read().splitlines()
ptrblck's avatar
ptrblck committed
76
        extras['pyprof'] = required_packages
Marek Kolodziej's avatar
Marek Kolodziej committed
77
78
79
80
81
82
83
    try:
        sys.argv.remove("--pyprof")
    except:
        pass
else:
    warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")

84
if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
Michael Carilli's avatar
Michael Carilli committed
85
86
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
87
                           "found torch.__version__ = {}".format(torch.__version__))
88
    cmdclass['build_ext'] = BuildExtension
89
90
91
92
93
94
if "--cpp_ext" in sys.argv:
    sys.argv.remove("--cpp_ext")
    ext_modules.append(
        CppExtension('apex_C',
                     ['csrc/flatten_unflatten.cpp',]))

ptrblck's avatar
ptrblck committed
95
# NOTE(review): a second, byte-for-byte duplicate definition of
# get_cuda_bare_metal_version previously appeared here, silently shadowing the
# identical function defined near the top of this file.  The duplicate has been
# removed; all callers resolve to the single remaining definition.

def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """Raise RuntimeError when the local nvcc version differs from the CUDA
    version PyTorch's binaries were compiled against."""
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major, torch_binary_minor = torch.version.cuda.split(".")[:2]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    # Both major and minor must agree; a minor mismatch is sometimes harmless
    # (see the linked discussion) but is still rejected here by default.
    if bare_metal_major != torch_binary_major or bare_metal_minor != torch_binary_minor:
        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
                           "not match the version used to compile Pytorch binaries.  " +
                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
                           "In some cases, a minor-version mismatch will not cause later errors:  " +
                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
                           "You can try commenting out this check (at your own risk).")
mcarilli's avatar
mcarilli committed
120

mcarilli's avatar
mcarilli committed
121
122
123
124
125
126
127
128
129
130
131
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
# Tuple comparison (MAJOR, MINOR) >= (1, k) is equivalent to the original
# "MAJOR > 1 or (MAJOR == 1 and MINOR > k-1)" checks.
version_ge_1_1 = ['-DVERSION_GE_1_1'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 1) else []
version_ge_1_3 = ['-DVERSION_GE_1_3'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 3) else []
version_ge_1_5 = ['-DVERSION_GE_1_5'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 5) else []
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
mcarilli's avatar
mcarilli committed
136

137
if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
138
    from torch.utils.cpp_extension import CUDAExtension
139
140
    if "--distributed_adam" in sys.argv:
        sys.argv.remove("--distributed_adam")
141
142
143
144

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

145
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
146
147
        raise RuntimeError("--distributed_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
148
149
        nvcc_args_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_adam = ['-O3'] + version_dependent_macros
150
151
152
153
        ext_modules.append(
            CUDAExtension(name='distributed_adam_cuda',
                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
154
155
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
156
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
157
                                              'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
158

159
if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
160
    from torch.utils.cpp_extension import CUDAExtension
161
162
    if "--distributed_lamb" in sys.argv:
        sys.argv.remove("--distributed_lamb")
163
164
165
166

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

167
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
168
169
        raise RuntimeError("--distributed_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
170
171
172
        print ("INFO: Building the distributed_lamb extension.")
        nvcc_args_distributed_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_distributed_lamb = ['-O3'] + version_dependent_macros
173
174
175
176
177
178
        ext_modules.append(
            CUDAExtension(name='distributed_lamb_cuda',
                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
179
                                              'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
180

jjsjann123's avatar
jjsjann123 committed
181
if "--cuda_ext" in sys.argv:
182
    from torch.utils.cpp_extension import CUDAExtension
183

184
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
Michael Carilli's avatar
Michael Carilli committed
185
        raise RuntimeError("--cuda_ext was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
186
    else:
187
        if not IS_ROCM_PYTORCH:
188
189
            check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)

190
191
192
193
        print ("INFO: Building the multi-tensor apply extension.")
        nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros
        ext_modules.append(
194
195
196
197
198
199
            CUDAExtension(name='amp_C',
                          sources=['csrc/amp_C_frontend.cpp',
                                   'csrc/multi_tensor_sgd_kernel.cu',
                                   'csrc/multi_tensor_scale_kernel.cu',
                                   'csrc/multi_tensor_axpby_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu',
200
                                   'csrc/multi_tensor_l2norm_kernel_mp.cu',
201
                                   'csrc/multi_tensor_l2norm_scale_kernel.cu',
202
203
204
205
206
                                   'csrc/multi_tensor_lamb_stage_1.cu',
                                   'csrc/multi_tensor_lamb_stage_2.cu',
                                   'csrc/multi_tensor_adam.cu',
                                   'csrc/multi_tensor_adagrad.cu',
                                   'csrc/multi_tensor_novograd.cu',
207
208
                                   'csrc/multi_tensor_lamb.cu',
                                   'csrc/multi_tensor_lamb_mp.cu'],
209
                          include_dirs=[os.path.join(this_dir, 'csrc')],
210
211
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor}))
212

lcskrishna's avatar
lcskrishna committed
213
        print ("INFO: Building syncbn extension.")
214
        ext_modules.append(
215
216
217
            CUDAExtension(name='syncbn',
                          sources=['csrc/syncbn.cpp',
                                   'csrc/welford.cu'],
218
                          include_dirs=[os.path.join(this_dir, 'csrc')],
219
220
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))
221

222
        nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
223
224
225
        hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
        print ("INFO: Building fused layernorm extension.")
        ext_modules.append(
226
227
228
            CUDAExtension(name='fused_layer_norm_cuda',
                          sources=['csrc/layer_norm_cuda.cpp',
                                   'csrc/layer_norm_cuda_kernel.cu'],
229
                          include_dirs=[os.path.join(this_dir, 'csrc')],
230
231
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
232

233
234
        print ("INFO: Building the MLP Extension.")
        ext_modules.append(
235
236
237
            CUDAExtension(name='mlp_cuda',
                          sources=['csrc/mlp.cpp',
                                   'csrc/mlp_cuda.cu'],
238
                          include_dirs=[os.path.join(this_dir, 'csrc')],
239
240
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))
241
242
243
244
245
246
        ext_modules.append(
            CUDAExtension(name='fused_dense_cuda',
                          sources=['csrc/fused_dense.cpp',
                                   'csrc/fused_dense_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))
247
        """
Masaki Kozuki's avatar
Masaki Kozuki committed
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
        ext_modules.append(
            CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
                                   'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(name='scaled_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_masked_softmax.cpp',
                                   'csrc/megatron/scaled_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))
271
        """
272

273
if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
jjsjann123's avatar
jjsjann123 committed
274
    from torch.utils.cpp_extension import CUDAExtension
275
276
    if "--bnp" in sys.argv:
        sys.argv.remove("--bnp")
jjsjann123's avatar
jjsjann123 committed
277
278
279
280

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

281
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
282
        raise RuntimeError("--bnp was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
jjsjann123's avatar
jjsjann123 committed
283
284
285
286
287
288
289
    else:
        ext_modules.append(
            CUDAExtension(name='bnp',
                          sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
                                   'apex/contrib/csrc/groupbn/ipc.cu',
                                   'apex/contrib/csrc/groupbn/interface.cpp',
                                   'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
290
291
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
mcarilli's avatar
mcarilli committed
292
                          extra_compile_args={'cxx': [] + version_dependent_macros,
jjsjann123's avatar
jjsjann123 committed
293
294
295
                                              'nvcc':['-DCUDA_HAS_FP16=1',
                                                      '-D__CUDA_NO_HALF_OPERATORS__',
                                                      '-D__CUDA_NO_HALF_CONVERSIONS__',
296
                                                      '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
jjsjann123's avatar
jjsjann123 committed
297

298
if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
299
    from torch.utils.cpp_extension import CUDAExtension
300
301
    if "--xentropy" in sys.argv:
        sys.argv.remove("--xentropy")
302
303
304
305

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

306
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
307
308
        raise RuntimeError("--xentropy was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
309
310
        print ("INFO: Building the xentropy extension.")
        ext_modules.append(
311
312
313
            CUDAExtension(name='xentropy_cuda',
                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
314
315
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
316
317
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))
318

319

320
if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
321
    from torch.utils.cpp_extension import CUDAExtension
322
323
    if "--deprecated_fused_adam" in sys.argv:
        sys.argv.remove("--deprecated_fused_adam")
324
325
326
327

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

328
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
329
330
        raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
331
332
333
334
        print ("INFO: Building deprecated fused adam extension.")
        nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
        ext_modules.append(
335
336
337
            CUDAExtension(name='fused_adam_cuda',
                          sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
                                   'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
338
339
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
340
341
342
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))

343
if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
344
    from torch.utils.cpp_extension import CUDAExtension
345
346
    if "--deprecated_fused_lamb" in sys.argv:
        sys.argv.remove("--deprecated_fused_lamb")
347
348
349
350

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

351
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
352
353
        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
354
355
356
357
        print ("INFO: Building deprecated fused lamb extension.")
        nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
        ext_modules.append(
358
359
360
361
362
363
            CUDAExtension(name='fused_lamb_cuda',
                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                                   'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args = nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb))
364

365
366
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# (the header moved; see https://github.com/pytorch/pytorch/pull/70650).
# OLD_GENERATOR_PATH is defined for the extensions below when the legacy
# top-level header is present in the installed torch.
torch_dir = torch.__path__[0]
_legacy_generator_header = os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")
generator_flag = ["-DOLD_GENERATOR_PATH"] if os.path.exists(_legacy_generator_header) else []
ptrblck's avatar
ptrblck committed
371

yjk21's avatar
yjk21 committed
372
373
374
if "--fast_layer_norm" in sys.argv:
    sys.argv.remove("--fast_layer_norm")

375
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
yjk21's avatar
yjk21 committed
376
377
378
379
        raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
Masaki Kozuki's avatar
Masaki Kozuki committed
380
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
yjk21's avatar
yjk21 committed
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
        if int(bare_metal_major) >= 11:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_80,code=sm_80')

        ext_modules.append(
            CUDAExtension(name='fast_layer_norm',
                          sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
                                   'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
                                   'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
                                   ],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
396
397
398
399
400
                                                      '-U__CUDA_NO_BFLOAT16_OPERATORS__',
                                                      '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
                                                      '-U__CUDA_NO_BFLOAT162_OPERATORS__',
                                                      '-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
                                                      '-I./apex/contrib/csrc/layer_norm/',
yjk21's avatar
yjk21 committed
401
402
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
403
404
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
yjk21's avatar
yjk21 committed
405
406
407
if "--fmha" in sys.argv:
    sys.argv.remove("--fmha")

408
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
yjk21's avatar
yjk21 committed
409
410
411
412
        raise RuntimeError("--fmha was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
Masaki Kozuki's avatar
Masaki Kozuki committed
413
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
yjk21's avatar
yjk21 committed
414
415
416
417
418
419
420
        if int(bare_metal_major) < 11:
            raise RuntimeError("--fmha only supported on SM80")

        ext_modules.append(
            CUDAExtension(name='fmhalib',
                          sources=[
                                   'apex/contrib/csrc/fmha/fmha_api.cpp',
yjk21's avatar
yjk21 committed
421
                                   'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
yjk21's avatar
yjk21 committed
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
                                   'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
                                   'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
                                   ],
                          extra_compile_args={'cxx': ['-O3',
                                                      ] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_80,code=sm_80',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
439
440
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
yjk21's avatar
yjk21 committed
441

ptrblck's avatar
ptrblck committed
442

443
if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension

    if "--fast_multihead_attn" in sys.argv:
        sys.argv.remove("--fast_multihead_attn")

    from torch.utils.cpp_extension import BuildExtension
    # ninja builds are disabled for this contrib extension; see the other
    # sections of this file that do the same.
    cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)

    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        if not IS_ROCM_PYTORCH:
            # BUGFIX: this previously read `cpp_extension.CUDA_HOME`, but the
            # bare name `cpp_extension` is never imported in this file, which
            # raised NameError on every non-ROCm build reaching this branch.
            # `CUDA_HOME` is imported at the top of the file and is what the
            # other extension sections use.
            _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
            if int(bare_metal_major) >= 11:
                cc_flag.append('-gencode')
                cc_flag.append('arch=compute_80,code=sm_80')

        # The multihead-attn kernels depend on the cutlass submodule.
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])

        # nvcc flags shared by every multihead-attn extension (CUDA builds).
        nvcc_args_mha = ['-O3',
                         '-gencode',
                         'arch=compute_70,code=sm_70',
                         '-Iapex/contrib/csrc/multihead_attn/cutlass',
                         '-U__CUDA_NO_HALF_OPERATORS__',
                         '-U__CUDA_NO_HALF_CONVERSIONS__',
                         '--expt-relaxed-constexpr',
                         '--expt-extended-lambda',
                         '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
        # hipcc flags shared by every multihead-attn extension (ROCm builds).
        hipcc_args_mha = ['-O3',
                          '-Iapex/contrib/csrc/multihead_attn/cutlass',
                          '-I/opt/rocm/include/hiprand',
                          '-I/opt/rocm/include/rocrand',
                          '-U__HIP_NO_HALF_OPERATORS__',
                          '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag

        # Each entry is (extension module name, source-file stem).  All eight
        # extensions share identical include dirs and compile args, so they
        # are registered from one table instead of eight copy-pasted appends.
        # Note the first two stems intentionally differ from the module name
        # ("mask" vs "masked").  Order matches the original append order.
        mha_extensions = [
            ('fast_additive_mask_softmax_dropout', 'additive_masked_softmax_dropout'),
            ('fast_mask_softmax_dropout', 'masked_softmax_dropout'),
            ('fast_self_multihead_attn_bias_additive_mask', 'self_multihead_attn_bias_additive_mask'),
            ('fast_self_multihead_attn_bias', 'self_multihead_attn_bias'),
            ('fast_self_multihead_attn', 'self_multihead_attn'),
            ('fast_self_multihead_attn_norm_add', 'self_multihead_attn_norm_add'),
            ('fast_encdec_multihead_attn', 'encdec_multihead_attn'),
            ('fast_encdec_multihead_attn_norm_add', 'encdec_multihead_attn_norm_add'),
        ]
        for ext_name, src_stem in mha_extensions:
            ext_modules.append(
                CUDAExtension(name=ext_name,
                              sources=['apex/contrib/csrc/multihead_attn/{}_cpp.cpp'.format(src_stem),
                                       'apex/contrib/csrc/multihead_attn/{}_cuda.cu'.format(src_stem)],
                              include_dirs=[os.path.join(this_dir, 'csrc'),
                                            os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                              extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                                  'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
543

544
545
546
if "--transducer" in sys.argv:
    sys.argv.remove("--transducer")

    # nvcc is mandatory for a CUDA build; ROCm builds go through hipcc.
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--transducer was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Both transducer extensions use the same flags; build a fresh dict
        # per extension so neither build can mutate the other's arguments.
        def _transducer_compile_args():
            return {'cxx': ['-O3'] + version_dependent_macros,
                    'nvcc': ['-O3'] + version_dependent_macros}

        ext_modules.append(
            CUDAExtension(
                name='transducer_joint_cuda',
                sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
                         'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
                extra_compile_args=_transducer_compile_args(),
                # NOTE(review): the multihead_attn dir is also on the include
                # path here — presumably shared headers; confirm before removing.
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
        ext_modules.append(
            CUDAExtension(
                name='transducer_loss_cuda',
                sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
                         'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args=_transducer_compile_args()))

565
566
567
if "--fast_bottleneck" in sys.argv:
    sys.argv.remove("--fast_bottleneck")

    # nvcc is mandatory for a CUDA build; ROCm builds go through hipcc.
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # The cudnn-frontend headers live in a git submodule; fetch them first.
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])

        cudnn_frontend_include = os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')
        ext_modules.append(
            CUDAExtension(
                name='fast_bottleneck',
                sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
                include_dirs=[cudnn_frontend_include],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + generator_flag}))

578
579
# Consume the --cuda_ext flag (it gated several sections above); setuptools
# would otherwise reject it as an unknown option.  EAFP: remove one
# occurrence, and ignore the error when the flag was never passed.
try:
    sys.argv.remove("--cuda_ext")
except ValueError:
    pass
580

Christian Sarofeen's avatar
Christian Sarofeen committed
581
setup(
    name='apex',
    version='0.1',
    # Exclude build artifacts and non-package directories from the install.
    # ('tests' previously appeared twice in this tuple; the duplicate was
    # removed — find_packages treats the tuple as a set of patterns, so one
    # entry suffices.)
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    description='PyTorch Extensions written by NVIDIA',
    # ext_modules / cmdclass / extras are accumulated by the option-gated
    # sections earlier in this file.
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    extras_require=extras,
)