setup.py 36.8 KB
Newer Older
1
import torch
Masaki Kozuki's avatar
Masaki Kozuki committed
2
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
3
from setuptools import setup, find_packages
mcarilli's avatar
mcarilli committed
4
import subprocess
5

jjsjann123's avatar
jjsjann123 committed
6
import sys
Marek Kolodziej's avatar
Marek Kolodziej committed
7
import warnings
mcarilli's avatar
mcarilli committed
8
import os
jjsjann123's avatar
jjsjann123 committed
9

10
11
12
# ninja build does not work unless include_dirs are abs path
# Absolute path of the directory containing this setup.py; used as the base
# for every include_dirs entry below.
this_dir = os.path.dirname(os.path.abspath(__file__))

ptrblck's avatar
ptrblck committed
13
14
15
16
17
18
19
20
21
22
def get_cuda_bare_metal_version(cuda_dir):
    """Query the nvcc under *cuda_dir* for the installed CUDA toolkit version.

    Runs ``<cuda_dir>/bin/nvcc -V`` and parses the token following
    ``release`` (e.g. ``"11.1,"``).

    Returns:
        Tuple ``(raw_nvcc_output, major, minor)`` where *major* and *minor*
        are strings (minor is the first digit only).
    """
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    tokens = raw_output.split()
    # The version token sits immediately after the literal word "release".
    release = tokens[tokens.index("release") + 1].split(".")
    return raw_output, release[0], release[1][0]

Jithun Nair's avatar
Jithun Nair committed
23
24
25
26
# Announce which PyTorch we are building against, then split its version
# string into integer major/minor components used by the gates below.
print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
_torch_version_parts = torch.__version__.split('.')
TORCH_MAJOR = int(_torch_version_parts[0])
TORCH_MINOR = int(_torch_version_parts[1])

27
28
def check_if_rocm_pytorch():
    """Return True when the installed PyTorch is a ROCm (HIP) build.

    ``ROCM_HOME`` is only exported by torch>=1.5, so the probe is guarded by
    the version parsed above; older torch is assumed to be a CUDA build.
    """
    if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
        from torch.utils.cpp_extension import ROCM_HOME
        return (torch.version.hip is not None) and (ROCM_HOME is not None)
    return False

IS_ROCM_PYTORCH = check_if_rocm_pytorch()

if not torch.cuda.is_available():
    # Building without a visible GPU: warn, and for CUDA builds pick a
    # default cross-compile architecture list.
    if not IS_ROCM_PYTORCH:
        # https://github.com/NVIDIA/apex/issues/486
        # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
        # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
              'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
              'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
              'If you wish to cross-compile for a single specific architecture,\n'
              'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
        # Respect a user-provided arch list; otherwise choose defaults based
        # on the bare-metal CUDA toolkit's major version.
        if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
            _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
            if int(bare_metal_major) == 11:
                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
            else:
                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
    else:
        # ROCm PyTorch already carries its own default gfx target list.
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for the same gfx targets\n'
              'used by default in ROCm PyTorch\n')
59

60
# Apex relies on APIs introduced in PyTorch 0.4; refuse anything older.
if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n"
                       "The latest stable release can be obtained from https://pytorch.org/")

jjsjann123's avatar
jjsjann123 committed
64
65
66
# Accumulators populated by the optional-extension flags handled below and
# finally handed to setuptools.setup().
cmdclass = {}
ext_modules = []
extras = {}
Marek Kolodziej's avatar
Marek Kolodziej committed
68
# --pyprof: install the (deprecated, now externally hosted) PyProf deps.
if "--pyprof" in sys.argv:
    string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
             "soon be removed from Apex.  Please visit\n" + \
             "https://github.com/NVIDIA/PyProf\n" + \
             "for the latest version."
    warnings.warn(string, DeprecationWarning)
    with open('requirements.txt') as f:
        required_packages = f.read().splitlines()
        extras['pyprof'] = required_packages
    try:
        sys.argv.remove("--pyprof")
    except ValueError:
        # Fix: was a bare `except:` that could mask unrelated errors
        # (including KeyboardInterrupt); only the missing-value case is
        # expected here.
        pass
else:
    warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")

84
# Building any C++/CUDA extension requires the torch 1.0+ extension API.
if any(flag in sys.argv for flag in ("--cpp_ext", "--cuda_ext")):
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
                           "found torch.__version__ = {}".format(torch.__version__))
94

# --cpp_ext: pure-C++ helpers (tensor flatten/unflatten) — no CUDA needed.
if "--cpp_ext" in sys.argv:
    sys.argv.remove("--cpp_ext")
    ext_modules.append(
        CppExtension('apex_C', ['csrc/flatten_unflatten.cpp']))

ptrblck's avatar
ptrblck committed
95
# NOTE(review): a second, byte-identical definition of
# get_cuda_bare_metal_version used to live here, silently shadowing the copy
# defined near the top of this file.  The duplicate has been removed; all
# callers below resolve to the single earlier definition.

def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """Raise unless the local nvcc matches the CUDA version torch was built with.

    Compares major.minor of the bare-metal toolkit under *cuda_dir* against
    ``torch.version.cuda`` and raises RuntimeError on any mismatch.
    """
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_cuda_parts = torch.version.cuda.split(".")
    torch_binary_major = torch_cuda_parts[0]
    torch_binary_minor = torch_cuda_parts[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if (bare_metal_major, bare_metal_minor) != (torch_binary_major, torch_binary_minor):
        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
                           "not match the version used to compile Pytorch binaries.  " +
                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
                           "In some cases, a minor-version mismatch will not cause later errors:  " +
                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
                           "You can try commenting out this check (at your own risk).")
mcarilli's avatar
mcarilli committed
120

mcarilli's avatar
mcarilli committed
121
122
123
124
125
126
127
128
129
130
131
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
# Each macro is emitted only when the installed torch reaches that version.
version_ge_1_1 = ['-DVERSION_GE_1_1'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 1) else []
version_ge_1_3 = ['-DVERSION_GE_1_3'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 3) else []
version_ge_1_5 = ['-DVERSION_GE_1_5'] if (TORCH_MAJOR, TORCH_MINOR) >= (1, 5) else []
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
mcarilli's avatar
mcarilli committed
136

137
138
139
# --distributed_adam: fused distributed Adam optimizer (CUDA only).
if "--distributed_adam" in sys.argv:
    sys.argv.remove("--distributed_adam")

    if CUDA_HOME is None:
        raise RuntimeError("--distributed_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(
                name='distributed_adam_cuda',
                sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
                         'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3', '--use_fast_math'] + version_dependent_macros}))

152
153
154
155
156
157
# --distributed_lamb: fused distributed LAMB optimizer (CUDA or ROCm).
if "--distributed_lamb" in sys.argv:
    sys.argv.remove("--distributed_lamb")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--distributed_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building the distributed_lamb extension.")
        # hipcc does not understand --use_fast_math, hence the split.
        nvcc_args_distributed_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_distributed_lamb = ['-O3'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='distributed_lamb_cuda',
                sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
                         'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
171

jjsjann123's avatar
jjsjann123 committed
172
173
# --cuda_ext: the core Apex CUDA extensions (multi-tensor apply, syncbn,
# fused layer norm, MLP, fused dense, Megatron softmax kernels).
if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--cuda_ext was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        if not IS_ROCM_PYTORCH:
            # Guard against a toolkit/torch CUDA version mismatch up front.
            check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)

        print("INFO: Building the multi-tensor apply extension.")
        # hipcc lacks -lineinfo/--use_fast_math, hence the separate flag sets.
        nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='amp_C',
                sources=['csrc/amp_C_frontend.cpp',
                         'csrc/multi_tensor_sgd_kernel.cu',
                         'csrc/multi_tensor_scale_kernel.cu',
                         'csrc/multi_tensor_axpby_kernel.cu',
                         'csrc/multi_tensor_l2norm_kernel.cu',
                         'csrc/multi_tensor_l2norm_scale_kernel.cu',
                         'csrc/multi_tensor_lamb_stage_1.cu',
                         'csrc/multi_tensor_lamb_stage_2.cu',
                         'csrc/multi_tensor_adam.cu',
                         'csrc/multi_tensor_adagrad.cu',
                         'csrc/multi_tensor_novograd.cu',
                         'csrc/multi_tensor_lamb.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor}))

        print("INFO: Building syncbn extension.")
        ext_modules.append(
            CUDAExtension(
                name='syncbn',
                sources=['csrc/syncbn.cpp',
                         'csrc/welford.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))

        print("INFO: Building fused layernorm extension.")
        nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='fused_layer_norm_cuda',
                sources=['csrc/layer_norm_cuda.cpp',
                         'csrc/layer_norm_cuda_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))

        print("INFO: Building the MLP Extension.")
        ext_modules.append(
            CUDAExtension(
                name='mlp_cuda',
                sources=['csrc/mlp.cpp',
                         'csrc/mlp_cuda.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(
                name='fused_dense_cuda',
                sources=['csrc/fused_dense.cpp',
                         'csrc/fused_dense_cuda.cu'],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))

        # Megatron-style fused softmax kernels need half-precision operators
        # and extended lambdas enabled in nvcc.
        ext_modules.append(
            CUDAExtension(
                name='scaled_upper_triang_masked_softmax_cuda',
                sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
                         'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(
                name='scaled_masked_softmax_cuda',
                sources=['csrc/megatron/scaled_masked_softmax.cpp',
                         'csrc/megatron/scaled_masked_softmax_cuda.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda'] + version_dependent_macros}))
260

jjsjann123's avatar
jjsjann123 committed
261
262
263
264
265
266
# --bnp: persistent group batch-norm kernels.
if "--bnp" in sys.argv:
    sys.argv.remove("--bnp")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--bnp was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(
                name='bnp',
                sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
                         'apex/contrib/csrc/groupbn/ipc.cu',
                         'apex/contrib/csrc/groupbn/interface.cpp',
                         'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
                extra_compile_args={'cxx': [] + version_dependent_macros,
                                    'nvcc': ['-DCUDA_HAS_FP16=1',
                                             '-D__CUDA_NO_HALF_OPERATORS__',
                                             '-D__CUDA_NO_HALF_CONVERSIONS__',
                                             '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
jjsjann123's avatar
jjsjann123 committed
283

284
285
286
287
288
289
# --xentropy: fused label-smoothed cross entropy.
if "--xentropy" in sys.argv:
    sys.argv.remove("--xentropy")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--xentropy was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building the xentropy extension.")
        ext_modules.append(
            CUDAExtension(
                name='xentropy_cuda',
                sources=['apex/contrib/csrc/xentropy/interface.cpp',
                         'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))
301

302

303
304
305
306
307
308
# --deprecated_fused_adam: legacy fused Adam kept for backward compatibility.
if "--deprecated_fused_adam" in sys.argv:
    sys.argv.remove("--deprecated_fused_adam")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building deprecated fused adam extension.")
        nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='fused_adam_cuda',
                sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
                         'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))

323
324
325
326
327
328
# --deprecated_fused_lamb: legacy fused LAMB kept for backward compatibility.
if "--deprecated_fused_lamb" in sys.argv:
    sys.argv.remove("--deprecated_fused_lamb")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building deprecated fused lamb extension.")
        nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='fused_lamb_cuda',
                sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                         'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
                         'csrc/multi_tensor_l2norm_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                # Fix: extra_compile_args was previously passed a bare list,
                # which feeds nvcc-only flags (e.g. --use_fast_math) to the
                # host C++ compiler as well.  Use the per-compiler dict form
                # like every other extension in this file.
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb}))
342

343
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
torch_dir = torch.__path__[0]
_old_generator_header = os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')
generator_flag = ['-DOLD_GENERATOR'] if os.path.exists(_old_generator_header) else []

yjk21's avatar
yjk21 committed
349
350
351
# --fast_layer_norm: hand-tuned layer-norm kernels (CUDA only).
if "--fast_layer_norm" in sys.argv:
    sys.argv.remove("--fast_layer_norm")

    if CUDA_HOME is None:
        raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) >= 11:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_80,code=sm_80')

        ext_modules.append(
            CUDAExtension(
                name='fast_layer_norm',
                sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
                         'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
                         'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
                         ],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + generator_flag,
                                    'nvcc': ['-O3',
                                             '-gencode', 'arch=compute_70,code=sm_70',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda',
                                             '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
yjk21's avatar
yjk21 committed
377
378
379
# --fmha: fused multi-head attention kernels; requires an SM80-capable
# (CUDA >= 11) toolkit.
if "--fmha" in sys.argv:
    sys.argv.remove("--fmha")

    if CUDA_HOME is None:
        raise RuntimeError("--fmha was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) < 11:
            raise RuntimeError("--fmha only supported on SM80")

        ext_modules.append(
            CUDAExtension(
                name='fmhalib',
                sources=['apex/contrib/csrc/fmha/fmha_api.cpp',
                         'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
                         ],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + generator_flag,
                                    'nvcc': ['-O3',
                                             '-gencode', 'arch=compute_80,code=sm_80',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda',
                                             '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
yjk21's avatar
yjk21 committed
413

ptrblck's avatar
ptrblck committed
414

415
416
417
if "--fast_multihead_attn" in sys.argv:
    sys.argv.remove("--fast_multihead_attn")

Masaki Kozuki's avatar
Masaki Kozuki committed
418
    if CUDA_HOME is None:
419
420
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
ptrblck's avatar
ptrblck committed
421
422
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
Masaki Kozuki's avatar
Masaki Kozuki committed
423
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
ptrblck's avatar
ptrblck committed
424
425
426
427
        if int(bare_metal_major) >= 11:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_80,code=sm_80')

428
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
429
430
431
432
433
434
435
436
437
438
439
        ext_modules.append(
            CUDAExtension(name='fast_additive_mask_softmax_dropout',
                          sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
                                   'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
440
441
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
442
443
444
445
446
447
448
449
450
451
452
        ext_modules.append(
            CUDAExtension(name='fast_mask_softmax_dropout',
                          sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
                                   'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
453
454
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
455
456
457
458
459
460
461
462
463
464
465
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
466
467
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
468
469
470
471
472
473
474
475
476
477
478
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc':['-O3',
                                                      '-gencode', 'arch=compute_70,code=sm_70',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
479
480
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
481
482
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn',
483
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
484
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
ptrblck's avatar
ptrblck committed
485
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
486
                                              'nvcc':['-O3',
487
                                                      '-gencode', 'arch=compute_70,code=sm_70',
488
489
490
491
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
492
493
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
494
495
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_norm_add',
496
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
497
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
ptrblck's avatar
ptrblck committed
498
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
499
                                              'nvcc':['-O3',
500
                                                      '-gencode', 'arch=compute_70,code=sm_70',
501
502
503
504
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
505
506
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
507
508
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn',
509
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
510
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
ptrblck's avatar
ptrblck committed
511
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
512
                                              'nvcc':['-O3',
513
                                                      '-gencode', 'arch=compute_70,code=sm_70',
514
515
516
517
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
518
519
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
520
521
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
522
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
523
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
ptrblck's avatar
ptrblck committed
524
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
525
                                              'nvcc':['-O3',
526
                                                      '-gencode', 'arch=compute_70,code=sm_70',
527
528
529
530
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda',
Masaki Kozuki's avatar
Masaki Kozuki committed
531
532
                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
# Optional transducer extensions, enabled with the --transducer install flag.
if "--transducer" in sys.argv:
    sys.argv.remove("--transducer")

    # Building CUDA extensions requires nvcc; CUDA_HOME is None when it is absent.
    if CUDA_HOME is None:
        raise RuntimeError("--transducer was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Transducer joint-network kernel.
        ext_modules.append(
            CUDAExtension(name='transducer_joint_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
                                   'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros},
                          include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
        # Transducer loss kernel.
        ext_modules.append(
            CUDAExtension(name='transducer_loss_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
                                   'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros}))

# Optional fused bottleneck extension, enabled with the --fast_bottleneck install flag.
if "--fast_bottleneck" in sys.argv:
    sys.argv.remove("--fast_bottleneck")

    # Building CUDA extensions requires nvcc; CUDA_HOME is None when it is absent.
    if CUDA_HOME is None:
        raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # The extension vendors the cudnn-frontend headers via a git submodule;
        # fetch it before compiling.  Deliberately best-effort (no check=True):
        # in a non-git source tree the headers may already be present.
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(name='fast_bottleneck',
                          sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
                          include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))

# Final package description.  `ext_modules` and `extras` are accumulated above
# based on the install flags that were passed on the command line.
setup(
    name='apex',
    version='0.1',
    # Exclude build artifacts and non-package directories from the wheel.
    # (The original tuple listed 'tests' twice; deduplicated.)
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    description='PyTorch Extensions written by NVIDIA',
    ext_modules=ext_modules,
    # Only hook in PyTorch's BuildExtension when there is something to compile.
    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
    extras_require=extras,
)