setup.py 36.5 KB
Newer Older
1
import torch
Masaki Kozuki's avatar
Masaki Kozuki committed
2
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
3
from setuptools import setup, find_packages
mcarilli's avatar
mcarilli committed
4
import subprocess
5

jjsjann123's avatar
jjsjann123 committed
6
import sys
Marek Kolodziej's avatar
Marek Kolodziej committed
7
import warnings
mcarilli's avatar
mcarilli committed
8
import os
jjsjann123's avatar
jjsjann123 committed
9

10
11
12
# ninja builds require absolute paths in include_dirs, so resolve the
# directory containing this setup.py once up front and reuse it below.
this_dir = os.path.dirname(os.path.abspath(__file__))

ptrblck's avatar
ptrblck committed
13
14
15
16
17
18
19
20
21
22
def get_cuda_bare_metal_version(cuda_dir):
    """Query the locally installed CUDA toolkit version via nvcc.

    Runs ``<cuda_dir>/bin/nvcc -V`` and parses the token that follows
    the word "release" (e.g. ``"11.1,"``).

    Returns:
        tuple: ``(raw nvcc output, major version string, first character
        of the minor version string)``.
    """
    nvcc_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    tokens = nvcc_output.split()
    version_token = tokens[tokens.index("release") + 1]
    parts = version_token.split(".")
    # Only the first character of the minor field is kept — it may carry a
    # trailing comma straight from the nvcc banner (e.g. "1,").
    return nvcc_output, parts[0], parts[1][0]

Jithun Nair's avatar
Jithun Nair committed
23
24
25
26
# Report the PyTorch version this build is compiling against, then cache
# its major/minor components for the feature checks below.
print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
_torch_version_parts = torch.__version__.split('.')
TORCH_MAJOR = int(_torch_version_parts[0])
TORCH_MINOR = int(_torch_version_parts[1])

27
28
def check_if_rocm_pytorch():
    """Return True when the installed PyTorch is a ROCm (HIP) build.

    ``ROCM_HOME`` only exists in ``torch.utils.cpp_extension`` from
    PyTorch 1.5 onward, so the import is guarded by a version check.
    On older PyTorch this always returns False.

    Improvement over the original: the version is derived locally from
    ``torch.__version__`` (identical values to the module-level
    TORCH_MAJOR/TORCH_MINOR globals it previously depended on), so the
    function is self-contained, and the redundant
    ``True if … else False`` ternary is replaced by a plain boolean
    expression.
    """
    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])
    is_rocm_pytorch = False
    if torch_major > 1 or (torch_major == 1 and torch_minor >= 5):
        from torch.utils.cpp_extension import ROCM_HOME
        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    return is_rocm_pytorch

# Cache the ROCm detection once; every option section below consults it.
IS_ROCM_PYTORCH = check_if_rocm_pytorch()

if not torch.cuda.is_available():
    if IS_ROCM_PYTORCH:
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for the same gfx targets\n'
              'used by default in ROCm PyTorch\n')
    else:
        # https://github.com/NVIDIA/apex/issues/486
        # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
        # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
              'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
              'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
              'If you wish to cross-compile for a single specific architecture,\n'
              'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
        # Only fall back to a default arch list when the user did not set one.
        if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
            _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
            default_arch_list = "6.0;6.1;6.2;7.0;7.5"
            if int(bare_metal_major) == 11:
                # Ampere requires CUDA 11.
                default_arch_list += ";8.0"
            os.environ["TORCH_CUDA_ARCH_LIST"] = default_arch_list
59

60
# Fail fast on PyTorch versions that predate the APIs Apex relies on.
if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n"
                       "The latest stable release can be obtained from https://pytorch.org/")

jjsjann123's avatar
jjsjann123 committed
64
65
66
# Accumulators that the option-handling sections below fill in.
cmdclass = {}     # setuptools command overrides (build_ext)
ext_modules = []  # C++/CUDA extension modules to compile
extras = {}       # optional extras_require entries
Marek Kolodziej's avatar
Marek Kolodziej committed
68
# Legacy --pyprof flag: PyProf now lives in its own repository; installing
# it from here is deprecated.
if "--pyprof" in sys.argv:
    deprecation_msg = ("\n\nPyprof has been moved to its own dedicated repository and will "
                       "soon be removed from Apex.  Please visit\n"
                       "https://github.com/NVIDIA/PyProf\n"
                       "for the latest version.")
    warnings.warn(deprecation_msg, DeprecationWarning)
    with open('requirements.txt') as f:
        extras['pyprof'] = f.read().splitlines()
    # NOTE: bare except kept deliberately — removal is best-effort and must
    # never abort the build.
    try:
        sys.argv.remove("--pyprof")
    except:
        pass
else:
    warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")

84
# Any compiled extension needs the PyTorch 1.x build machinery.
if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
                           "found torch.__version__ = {}".format(torch.__version__))
    cmdclass['build_ext'] = BuildExtension

# Pure C++ helper used by the amp / DDP utilities.
if "--cpp_ext" in sys.argv:
    sys.argv.remove("--cpp_ext")
    ext_modules.append(
        CppExtension('apex_C', ['csrc/flatten_unflatten.cpp',]))

ptrblck's avatar
ptrblck committed
95
# NOTE(review): this re-definition duplicates get_cuda_bare_metal_version
# declared earlier in this file; being later, it silently shadows the first
# (identical) definition.  Kept for byte-compatibility of module layout.
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``<cuda_dir>/bin/nvcc -V`` and extract the toolkit version.

    Same contract as the earlier definition: returns the raw nvcc banner,
    the major version string, and the first character of the minor field.
    """
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                         universal_newlines=True)
    words = raw_output.split()
    release_str = words[words.index("release") + 1]
    major, minor = release_str.split(".")[:2]
    return raw_output, major, minor[0]

def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """Refuse to build when the system nvcc version differs from the CUDA
    version the installed PyTorch binaries were compiled against.

    Raises:
        RuntimeError: on any major or minor version mismatch.
    """
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_cuda_parts = torch.version.cuda.split(".")
    torch_binary_major = torch_cuda_parts[0]
    torch_binary_minor = torch_cuda_parts[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    version_mismatch = (bare_metal_major != torch_binary_major
                        or bare_metal_minor != torch_binary_minor)
    if version_mismatch:
        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does "
                           "not match the version used to compile Pytorch binaries.  "
                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
                           "In some cases, a minor-version mismatch will not cause later errors:  "
                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
                           "You can try commenting out this check (at your own risk).")
mcarilli's avatar
mcarilli committed
120

mcarilli's avatar
mcarilli committed
121
122
123
124
125
126
127
128
129
130
131
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = ['-DVERSION_GE_1_1'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 0) else []
version_ge_1_3 = ['-DVERSION_GE_1_3'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 2) else []
version_ge_1_5 = ['-DVERSION_GE_1_5'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 4) else []
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
mcarilli's avatar
mcarilli committed
136

137
# Distributed (ZeRO-style) Adam optimizer kernels.
if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--distributed_adam" in sys.argv:
        sys.argv.remove("--distributed_adam")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--distributed_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # hipcc does not accept --use_fast_math.
        if IS_ROCM_PYTORCH:
            adam_device_args = ['-O3'] + version_dependent_macros
        else:
            adam_device_args = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='distributed_adam_cuda',
                sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
                         'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                    'nvcc': adam_device_args}))
158

159
# Distributed LAMB optimizer kernels.
if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--distributed_lamb" in sys.argv:
        sys.argv.remove("--distributed_lamb")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--distributed_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building the distributed_lamb extension.")
        # hipcc does not accept --use_fast_math.
        if IS_ROCM_PYTORCH:
            lamb_device_args = ['-O3'] + version_dependent_macros
        else:
            lamb_device_args = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='distributed_lamb_cuda',
                sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
                         'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                    'nvcc': lamb_device_args}))
180

jjsjann123's avatar
jjsjann123 committed
181
# Core --cuda_ext bundle: multi-tensor apply, syncbn, fused layernorm,
# fused MLP, and fused dense kernels.
if "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--cuda_ext was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        if not IS_ROCM_PYTORCH:
            check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)

        print("INFO: Building the multi-tensor apply extension.")
        # hipcc accepts neither -lineinfo nor --use_fast_math.
        if IS_ROCM_PYTORCH:
            multi_tensor_device_args = ['-O3'] + version_dependent_macros
        else:
            multi_tensor_device_args = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='amp_C',
                sources=['csrc/amp_C_frontend.cpp',
                         'csrc/multi_tensor_sgd_kernel.cu',
                         'csrc/multi_tensor_scale_kernel.cu',
                         'csrc/multi_tensor_axpby_kernel.cu',
                         'csrc/multi_tensor_l2norm_kernel.cu',
                         'csrc/multi_tensor_l2norm_scale_kernel.cu',
                         'csrc/multi_tensor_lamb_stage_1.cu',
                         'csrc/multi_tensor_lamb_stage_2.cu',
                         'csrc/multi_tensor_adam.cu',
                         'csrc/multi_tensor_adagrad.cu',
                         'csrc/multi_tensor_novograd.cu',
                         'csrc/multi_tensor_lamb.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': multi_tensor_device_args}))

        print("INFO: Building syncbn extension.")
        ext_modules.append(
            CUDAExtension(
                name='syncbn',
                sources=['csrc/syncbn.cpp',
                         'csrc/welford.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))

        print("INFO: Building fused layernorm extension.")
        if IS_ROCM_PYTORCH:
            layer_norm_device_args = ['-O3'] + version_dependent_macros
        else:
            layer_norm_device_args = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='fused_layer_norm_cuda',
                sources=['csrc/layer_norm_cuda.cpp',
                         'csrc/layer_norm_cuda_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': layer_norm_device_args}))

        print("INFO: Building the MLP Extension.")
        ext_modules.append(
            CUDAExtension(
                name='mlp_cuda',
                sources=['csrc/mlp.cpp',
                         'csrc/mlp_cuda.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(
                name='fused_dense_cuda',
                sources=['csrc/fused_dense.cpp',
                         'csrc/fused_dense_cuda.cu'],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))
        # Megatron fused-softmax extensions: currently disabled.  The string
        # literal below is a no-op at runtime, kept verbatim for reference.
        """
        ext_modules.append(
            CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
                                   'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(name='scaled_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_masked_softmax.cpp',
                                   'csrc/megatron/scaled_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))
        """
270

271
# Group batch-norm (NHWC persistent batchnorm) kernels.
if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--bnp" in sys.argv:
        sys.argv.remove("--bnp")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--bnp was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(
                name='bnp',
                sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
                         'apex/contrib/csrc/groupbn/ipc.cu',
                         'apex/contrib/csrc/groupbn/interface.cpp',
                         'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
                extra_compile_args={'cxx': [] + version_dependent_macros,
                                    'nvcc': ['-DCUDA_HAS_FP16=1',
                                             '-D__CUDA_NO_HALF_OPERATORS__',
                                             '-D__CUDA_NO_HALF_CONVERSIONS__',
                                             '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
jjsjann123's avatar
jjsjann123 committed
295

296
# Fused label-smoothed cross-entropy kernels.
if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--xentropy" in sys.argv:
        sys.argv.remove("--xentropy")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--xentropy was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building the xentropy extension.")
        ext_modules.append(
            CUDAExtension(
                name='xentropy_cuda',
                sources=['apex/contrib/csrc/xentropy/interface.cpp',
                         'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': ['-O3'] + version_dependent_macros}))
316

317

318
# Deprecated fused Adam optimizer (superseded by apex.optimizers.FusedAdam).
if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--deprecated_fused_adam" in sys.argv:
        sys.argv.remove("--deprecated_fused_adam")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building deprecated fused adam extension.")
        # hipcc does not accept --use_fast_math.
        if IS_ROCM_PYTORCH:
            fused_adam_device_args = ['-O3'] + version_dependent_macros
        else:
            fused_adam_device_args = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(
                name='fused_adam_cuda',
                sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
                         'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc'),
                              os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': fused_adam_device_args}))

341
# Deprecated fused LAMB optimizer (superseded by apex.optimizers.FusedLAMB).
if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--deprecated_fused_lamb" in sys.argv:
        sys.argv.remove("--deprecated_fused_lamb")

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print("INFO: Building deprecated fused lamb extension.")
        nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
        # FIX: extra_compile_args used to be passed as a flat list, which fed
        # device-only flags (e.g. --use_fast_math) to the host C++ compiler as
        # well.  Every other extension in this file uses the
        # {'cxx': ..., 'nvcc': ...} mapping; do the same here.
        ext_modules.append(
            CUDAExtension(
                name='fused_lamb_cuda',
                sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                         'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
                         'csrc/multi_tensor_l2norm_kernel.cu'],
                include_dirs=[os.path.join(this_dir, 'csrc')],
                extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                    'nvcc': nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb}))
362

363
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new
# ATen/CUDAGeneratorImpl.h, due to breaking change in
# https://github.com/pytorch/pytorch/pull/36026
generator_flag = []
torch_dir = torch.__path__[0]
old_generator_header = os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')
if os.path.exists(old_generator_header):
    generator_flag = ['-DOLD_GENERATOR']

yjk21's avatar
yjk21 committed
369
370
371
# Optimized fused layer-norm (separate from the --cuda_ext fused layernorm).
if "--fast_layer_norm" in sys.argv:
    sys.argv.remove("--fast_layer_norm")

    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) >= 11:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_80,code=sm_80')

        ext_modules.append(
            CUDAExtension(
                name='fast_layer_norm',
                sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
                         'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
                         'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
                         ],
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
                extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                    'nvcc': ['-O3',
                                             '-gencode', 'arch=compute_70,code=sm_70',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '-U__CUDA_NO_BFLOAT16_OPERATORS__',
                                             '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
                                             '-U__CUDA_NO_BFLOAT162_OPERATORS__',
                                             '-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
                                             '-I./apex/contrib/csrc/layer_norm/',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda',
                                             '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
yjk21's avatar
yjk21 committed
402
403
404
# Fused multi-head attention kernels (SM80 / CUDA 11 only).
if "--fmha" in sys.argv:
    sys.argv.remove("--fmha")

    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fmha was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) < 11:
            raise RuntimeError("--fmha only supported on SM80")

        ext_modules.append(
            CUDAExtension(
                name='fmhalib',
                sources=['apex/contrib/csrc/fmha/fmha_api.cpp',
                         'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
                         'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
                         ],
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"),
                              os.path.join(this_dir, "apex/contrib/csrc/fmha/src")],
                extra_compile_args={'cxx': ['-O3',
                                            ] + version_dependent_macros + generator_flag,
                                    'nvcc': ['-O3',
                                             '-gencode', 'arch=compute_80,code=sm_80',
                                             '-U__CUDA_NO_HALF_OPERATORS__',
                                             '-U__CUDA_NO_HALF_CONVERSIONS__',
                                             '--expt-relaxed-constexpr',
                                             '--expt-extended-lambda',
                                             '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
yjk21's avatar
yjk21 committed
438

ptrblck's avatar
ptrblck committed
439

440
# Fast multi-head attention kernels (depend on the cutlass submodule).
if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
    if "--fast_multihead_attn" in sys.argv:
        sys.argv.remove("--fast_multihead_attn")

    from torch.utils.cpp_extension import BuildExtension
    # ninja is disabled for this extension.
    cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        if not IS_ROCM_PYTORCH:
            # FIX: this previously read `cpp_extension.CUDA_HOME`, but no name
            # `cpp_extension` is imported anywhere in this file (only
            # `torch.utils.cpp_extension`), so reaching this line on a CUDA
            # build raised NameError.
            _, bare_metal_major, _ = get_cuda_bare_metal_version(torch.utils.cpp_extension.CUDA_HOME)
            if int(bare_metal_major) >= 11:
                cc_flag.append('-gencode')
                cc_flag.append('arch=compute_80,code=sm_80')

        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])

        nvcc_args_mha = ['-O3',
                         '-gencode',
                         'arch=compute_70,code=sm_70',
                         '-Iapex/contrib/csrc/multihead_attn/cutlass',
                         '-U__CUDA_NO_HALF_OPERATORS__',
                         '-U__CUDA_NO_HALF_CONVERSIONS__',
                         '--expt-relaxed-constexpr',
                         '--expt-extended-lambda',
                         '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
        hipcc_args_mha = ['-O3',
                          '-Iapex/contrib/csrc/multihead_attn/cutlass',
                          '-I/opt/rocm/include/hiprand',
                          '-I/opt/rocm/include/rocrand',
                          '-U__HIP_NO_HALF_OPERATORS__',
                          '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag

        ext_modules.append(
            CUDAExtension(name='fast_additive_mask_softmax_dropout',
                          sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp',
                                   'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
484
485
        ext_modules.append(
            CUDAExtension(name='fast_mask_softmax_dropout',
486
                          sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp',
487
                                   'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
488
489
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
490
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
491
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
492
493
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
494
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp',
495
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
496
497
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
498
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
499
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
500
501
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias',
502
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp',
503
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
504
505
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
506
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
507
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
508
509
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn',
510
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp',
511
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
512
513
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
514
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
515
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
516
517
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_norm_add',
518
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp',
519
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
520
521
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
522
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
523
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
524
525
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn',
526
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp',
527
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
528
529
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
530
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
531
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
532
533
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
534
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp',
535
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
536
537
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
538
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
539
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
540

541
542
543
if "--transducer" in sys.argv:
    sys.argv.remove("--transducer")

544
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
545
546
547
548
549
550
551
        raise RuntimeError("--transducer was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(name='transducer_joint_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
                                   'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
Masaki Kozuki's avatar
Masaki Kozuki committed
552
553
                                              'nvcc': ['-O3'] + version_dependent_macros},
                          include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
554
555
556
557
558
559
560
561
        ext_modules.append(
            CUDAExtension(name='transducer_loss_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
                                   'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))

562
563
564
if "--fast_bottleneck" in sys.argv:
    sys.argv.remove("--fast_bottleneck")

565
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
566
567
568
569
570
571
        raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(name='fast_bottleneck',
                          sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
Masaki Kozuki's avatar
Masaki Kozuki committed
572
                          include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
573
574
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))

575
576
if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")
577

Christian Sarofeen's avatar
Christian Sarofeen committed
578
setup(
    name='apex',
    version='0.1',
    # Exclude build artifacts, native sources, docs and tests from the
    # installed package.  (A duplicate 'tests' entry was removed.)
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    description='PyTorch Extensions written by NVIDIA',
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    extras_require=extras,
)