setup.py 36.7 KB
Newer Older
1
import torch
Masaki Kozuki's avatar
Masaki Kozuki committed
2
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
3
from setuptools import setup, find_packages
mcarilli's avatar
mcarilli committed
4
import subprocess
5

jjsjann123's avatar
jjsjann123 committed
6
import sys
Marek Kolodziej's avatar
Marek Kolodziej committed
7
import warnings
mcarilli's avatar
mcarilli committed
8
import os
jjsjann123's avatar
jjsjann123 committed
9

10
11
12
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.abspath(os.path.dirname(__file__))

ptrblck's avatar
ptrblck committed
13
14
15
16
17
18
19
20
21
22
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` under *cuda_dir* and return (raw output, major, minor).

    Note: only the first character of the minor version is returned
    (e.g. "release 10.2," -> major "10", minor "2").
    """
    nvcc_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    tokens = nvcc_output.split()
    # The token following "release" looks like "10.2," — take major/minor from it.
    version_parts = tokens[tokens.index("release") + 1].split(".")
    return nvcc_output, version_parts[0], version_parts[1][0]

Jithun Nair's avatar
Jithun Nair committed
23
24
25
26
# Announce which torch the extensions will be compiled against, and parse
# out the (major, minor) pair used for feature gating below.
print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
_torch_version_parts = torch.__version__.split('.')
TORCH_MAJOR = int(_torch_version_parts[0])
TORCH_MINOR = int(_torch_version_parts[1])

27
28
def check_if_rocm_pytorch():
    """Return True when this PyTorch build targets ROCm (HIP) rather than CUDA."""
    # ROCM_HOME only exists in torch >= 1.5, so guard the import by version.
    if (TORCH_MAJOR, TORCH_MINOR) < (1, 5):
        return False
    from torch.utils.cpp_extension import ROCM_HOME
    return (torch.version.hip is not None) and (ROCM_HOME is not None)

IS_ROCM_PYTORCH = check_if_rocm_pytorch()

if not torch.cuda.is_available():
    if not IS_ROCM_PYTORCH:
        # https://github.com/NVIDIA/apex/issues/486
        # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
        # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
              'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
              'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
              'If you wish to cross-compile for a single specific architecture,\n'
              'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
        # Respect a user-supplied arch list; otherwise pick a default from the
        # installed CUDA toolkit's major version.
        if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
            _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
            arch_list = "6.0;6.1;6.2;7.0;7.5"
            if int(bare_metal_major) == 11:
                arch_list += ";8.0"
            os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
    else:
        print('\nWarning: Torch did not find available GPUs on this system.\n',
              'If your intention is to cross-compile, this is not an error.\n'
              'By default, Apex will cross-compile for the same gfx targets\n'
              'used by default in ROCm PyTorch\n')
59

60
# Hard floor on the supported PyTorch version.
if (TORCH_MAJOR, TORCH_MINOR) < (0, 4):
    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

jjsjann123's avatar
jjsjann123 committed
64
65
66
# Accumulators that the flag-handling sections below fill in and that are
# eventually handed to setuptools.setup().
cmdclass = dict()
ext_modules = list()
extras = dict()
Marek Kolodziej's avatar
Marek Kolodziej committed
68
if "--pyprof" in sys.argv:
    # PyProf now lives in its own repository; warn that this flag is deprecated.
    string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
             "soon be removed from Apex.  Please visit\n" + \
             "https://github.com/NVIDIA/PyProf\n" + \
             "for the latest version."
    warnings.warn(string, DeprecationWarning)
    with open('requirements.txt') as f:
        required_packages = f.read().splitlines()
        extras['pyprof'] = required_packages
    # Strip the custom flag so setuptools does not choke on it.
    # Bug fix: the original used `try: ... except: pass` with a bare except,
    # which silently swallows any exception (including KeyboardInterrupt /
    # SystemExit).  Membership was already established above, so a guarded
    # remove is sufficient and hides nothing.
    if "--pyprof" in sys.argv:
        sys.argv.remove("--pyprof")
else:
    warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")

84
# Either flag implies building C++ extensions, which needs torch's BuildExtension.
if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
                           "found torch.__version__ = {}".format(torch.__version__))
    cmdclass['build_ext'] = BuildExtension

if "--cpp_ext" in sys.argv:
    sys.argv.remove("--cpp_ext")
    # Pure C++ (no CUDA) tensor flatten/unflatten helpers.
    ext_modules.append(
        CppExtension('apex_C', ['csrc/flatten_unflatten.cpp',]))

ptrblck's avatar
ptrblck committed
95
# NOTE(review): this is a duplicate of the identical get_cuda_bare_metal_version
# defined earlier in this file; one of the two definitions should be removed.
def get_cuda_bare_metal_version(cuda_dir):
    """Return (raw ``nvcc -V`` output, major version, first digit of minor)."""
    raw = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                  universal_newlines=True)
    words = raw.split()
    # Version token follows the literal word "release", e.g. "release 10.2,".
    release = words[words.index("release") + 1].split(".")
    return raw, release[0], release[1][0]

def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """Raise if the system nvcc version disagrees with the CUDA version torch
    was compiled against (major and minor must both match)."""
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major = torch.version.cuda.split(".")[0]
    torch_binary_minor = torch.version.cuda.split(".")[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if (bare_metal_major, bare_metal_minor) != (torch_binary_major, torch_binary_minor):
        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
                           "not match the version used to compile Pytorch binaries.  " +
                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
                           "In some cases, a minor-version mismatch will not cause later errors:  " +
                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
                           "You can try commenting out this check (at your own risk).")
mcarilli's avatar
mcarilli committed
120

mcarilli's avatar
mcarilli committed
121
122
123
124
125
126
127
128
129
130
131
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = ['-DVERSION_GE_1_1'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 0) else []
version_ge_1_3 = ['-DVERSION_GE_1_3'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 2) else []
version_ge_1_5 = ['-DVERSION_GE_1_5'] if (TORCH_MAJOR, TORCH_MINOR) > (1, 4) else []
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
mcarilli's avatar
mcarilli committed
136

137
if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
    # Distributed Adam optimizer kernels (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--distributed_adam" in sys.argv:
        sys.argv.remove("--distributed_adam")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--distributed_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        if IS_ROCM_PYTORCH:
            adam_gpu_flags = ['-O3'] + version_dependent_macros
        else:
            adam_gpu_flags = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='distributed_adam_cuda',
                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': adam_gpu_flags}))
158

159
if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
    # Distributed LAMB optimizer kernels (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--distributed_lamb" in sys.argv:
        sys.argv.remove("--distributed_lamb")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--distributed_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print ("INFO: Building the distributed_lamb extension.")
        if IS_ROCM_PYTORCH:
            lamb_gpu_flags = ['-O3'] + version_dependent_macros
        else:
            lamb_gpu_flags = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='distributed_lamb_cuda',
                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': lamb_gpu_flags}))
180

jjsjann123's avatar
jjsjann123 committed
181
if "--cuda_ext" in sys.argv:
    # Core CUDA extensions.  The flag itself is intentionally left in sys.argv
    # here, since several other sections also key off "--cuda_ext".
    from torch.utils.cpp_extension import CUDAExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--cuda_ext was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        if not IS_ROCM_PYTORCH:
            # nvcc must agree with the CUDA version the torch binaries were built with.
            check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)

        print ("INFO: Building the multi-tensor apply extension.")
        if IS_ROCM_PYTORCH:
            amp_C_gpu_flags = ['-O3'] + version_dependent_macros
        else:
            amp_C_gpu_flags = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='amp_C',
                          sources=['csrc/amp_C_frontend.cpp',
                                   'csrc/multi_tensor_sgd_kernel.cu',
                                   'csrc/multi_tensor_scale_kernel.cu',
                                   'csrc/multi_tensor_axpby_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel_mp.cu',
                                   'csrc/multi_tensor_l2norm_scale_kernel.cu',
                                   'csrc/multi_tensor_lamb_stage_1.cu',
                                   'csrc/multi_tensor_lamb_stage_2.cu',
                                   'csrc/multi_tensor_adam.cu',
                                   'csrc/multi_tensor_adagrad.cu',
                                   'csrc/multi_tensor_novograd.cu',
                                   'csrc/multi_tensor_lamb.cu',
                                   'csrc/multi_tensor_lamb_mp.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': amp_C_gpu_flags}))

        print ("INFO: Building syncbn extension.")
        ext_modules.append(
            CUDAExtension(name='syncbn',
                          sources=['csrc/syncbn.cpp',
                                   'csrc/welford.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros}))

        print ("INFO: Building fused layernorm extension.")
        if IS_ROCM_PYTORCH:
            layer_norm_gpu_flags = ['-O3'] + version_dependent_macros
        else:
            layer_norm_gpu_flags = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='fused_layer_norm_cuda',
                          sources=['csrc/layer_norm_cuda.cpp',
                                   'csrc/layer_norm_cuda_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': layer_norm_gpu_flags}))

        print ("INFO: Building the MLP Extension.")
        ext_modules.append(
            CUDAExtension(name='mlp_cuda',
                          sources=['csrc/mlp.cpp',
                                   'csrc/mlp_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(name='fused_dense_cuda',
                          sources=['csrc/fused_dense.cpp',
                                   'csrc/fused_dense_cuda.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros}))
        # The triple-quoted string below is commented-out code kept verbatim.
        """
        ext_modules.append(
            CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
                                   'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))

        ext_modules.append(
            CUDAExtension(name='scaled_masked_softmax_cuda',
                          sources=['csrc/megatron/scaled_masked_softmax.cpp',
                                   'csrc/megatron/scaled_masked_softmax_cuda.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3',
                                                      '-U__CUDA_NO_HALF_OPERATORS__',
                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                      '--expt-relaxed-constexpr',
                                                      '--expt-extended-lambda'] + version_dependent_macros}))
        """
272

273
if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
    # Group batch-norm extension (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--bnp" in sys.argv:
        sys.argv.remove("--bnp")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--bnp was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        bnp_nvcc_flags = ['-DCUDA_HAS_FP16=1',
                          '-D__CUDA_NO_HALF_OPERATORS__',
                          '-D__CUDA_NO_HALF_CONVERSIONS__',
                          '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='bnp',
                          sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
                                   'apex/contrib/csrc/groupbn/ipc.cu',
                                   'apex/contrib/csrc/groupbn/interface.cpp',
                                   'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
                          extra_compile_args={'cxx': [] + version_dependent_macros,
                                              'nvcc': bnp_nvcc_flags}))
jjsjann123's avatar
jjsjann123 committed
297

298
if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
    # Fused softmax cross-entropy extension (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--xentropy" in sys.argv:
        sys.argv.remove("--xentropy")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--xentropy was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print ("INFO: Building the xentropy extension.")
        ext_modules.append(
            CUDAExtension(name='xentropy_cuda',
                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros}))
318

319

320
if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
    # Legacy fused Adam kernels (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--deprecated_fused_adam" in sys.argv:
        sys.argv.remove("--deprecated_fused_adam")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print ("INFO: Building deprecated fused adam extension.")
        if IS_ROCM_PYTORCH:
            fused_adam_gpu_flags = ['-O3'] + version_dependent_macros
        else:
            fused_adam_gpu_flags = ['-O3', '--use_fast_math'] + version_dependent_macros
        ext_modules.append(
            CUDAExtension(name='fused_adam_cuda',
                          sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
                                   'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': fused_adam_gpu_flags}))

343
if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
    # Legacy fused LAMB kernels (also built as part of --cuda_ext).
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    if "--deprecated_fused_lamb" in sys.argv:
        sys.argv.remove("--deprecated_fused_lamb")

    cmdclass['build_ext'] = BuildExtension

    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        print ("INFO: Building deprecated fused lamb extension.")
        nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
        hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
        # Bug fix: extra_compile_args used to be passed as a bare list, which
        # hands the nvcc-only flag '--use_fast_math' to the host C++ compiler
        # as well, and is inconsistent with every other extension in this file.
        # Use the {'cxx': ..., 'nvcc': ...} dict form instead.
        ext_modules.append(
            CUDAExtension(name='fused_lamb_cuda',
                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                                   'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb}))
364

365
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
torch_dir = torch.__path__[0]
_old_generator_header = os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')
generator_flag = ['-DOLD_GENERATOR'] if os.path.exists(_old_generator_header) else []

yjk21's avatar
yjk21 committed
371
372
373
if "--fast_layer_norm" in sys.argv:
    sys.argv.remove("--fast_layer_norm")

    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        cc_flag = ['-gencode', 'arch=compute_80,code=sm_80'] if int(bare_metal_major) >= 11 else []

        fln_nvcc_flags = ['-O3',
                          '-gencode', 'arch=compute_70,code=sm_70',
                          '-U__CUDA_NO_HALF_OPERATORS__',
                          '-U__CUDA_NO_HALF_CONVERSIONS__',
                          '-U__CUDA_NO_BFLOAT16_OPERATORS__',
                          '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
                          '-U__CUDA_NO_BFLOAT162_OPERATORS__',
                          '-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
                          '-I./apex/contrib/csrc/layer_norm/',
                          '--expt-relaxed-constexpr',
                          '--expt-extended-lambda',
                          '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
        ext_modules.append(
            CUDAExtension(name='fast_layer_norm',
                          sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
                                   'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
                                   'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
                                   ],
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                              'nvcc': fln_nvcc_flags},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
yjk21's avatar
yjk21 committed
404
405
406
if "--fmha" in sys.argv:
    sys.argv.remove("--fmha")

    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
        raise RuntimeError("--fmha was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) < 11:
            raise RuntimeError("--fmha only supported on SM80")

        fmha_sources = ['apex/contrib/csrc/fmha/fmha_api.cpp',
                        'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
                        'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
                        'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu']
        fmha_nvcc_flags = ['-O3',
                           '-gencode', 'arch=compute_80,code=sm_80',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
        ext_modules.append(
            CUDAExtension(name='fmhalib',
                          sources=fmha_sources,
                          extra_compile_args={'cxx': ['-O3',
                                                      ] + version_dependent_macros + generator_flag,
                                              'nvcc': fmha_nvcc_flags},
                          include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
yjk21's avatar
yjk21 committed
440

ptrblck's avatar
ptrblck committed
441

442
if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
443
    from torch.utils.cpp_extension import CUDAExtension
444
445
    if "--fast_multihead_attn" in sys.argv:
        sys.argv.remove("--fast_multihead_attn")
446
447

    from torch.utils.cpp_extension import BuildExtension
448
    cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
449

450
    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
451
452
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
ptrblck's avatar
ptrblck committed
453
454
        # Check, if CUDA11 is installed for compute capability 8.0
        cc_flag = []
455
456
457
458
459
        if not IS_ROCM_PYTORCH:
            # Bug fix: the bare name `cpp_extension` is never imported in this
            # file (only `from torch.utils.cpp_extension import ... CUDA_HOME`),
            # so `cpp_extension.CUDA_HOME` raised NameError.  Use the imported
            # CUDA_HOME directly, matching the rest of the file.
            _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
            if int(bare_metal_major) >= 11:
                cc_flag.append('-gencode')
                cc_flag.append('arch=compute_80,code=sm_80')
ptrblck's avatar
ptrblck committed
460

461
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
        nvcc_args_mha = ['-O3',
                         '-gencode',
                         'arch=compute_70,code=sm_70',
                         '-Iapex/contrib/csrc/multihead_attn/cutlass',
                         '-U__CUDA_NO_HALF_OPERATORS__',
                         '-U__CUDA_NO_HALF_CONVERSIONS__',
                         '--expt-relaxed-constexpr',
                         '--expt-extended-lambda',
                         '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
        hipcc_args_mha = ['-O3',
                          '-Iapex/contrib/csrc/multihead_attn/cutlass',
                          '-I/opt/rocm/include/hiprand',
                          '-I/opt/rocm/include/rocrand',
                          '-U__HIP_NO_HALF_OPERATORS__',
                          '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
477

478
479
        ext_modules.append(
            CUDAExtension(name='fast_additive_mask_softmax_dropout',
480
                          sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp',
481
                                   'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
482
483
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
484
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
485
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
486
487
        ext_modules.append(
            CUDAExtension(name='fast_mask_softmax_dropout',
488
                          sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp',
489
                                   'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
490
491
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
492
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
493
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
494
495
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
496
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp',
497
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
498
499
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
500
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
501
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
502
503
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_bias',
504
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp',
505
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
506
507
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
508
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
509
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
510
511
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn',
512
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp',
513
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
514
515
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
516
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
517
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
518
519
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_norm_add',
520
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp',
521
                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
522
523
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
524
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
525
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
526
527
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn',
528
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp',
529
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
530
531
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
532
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
533
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
534
535
        ext_modules.append(
            CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
536
                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp',
537
                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
538
539
                          include_dirs=[os.path.join(this_dir, 'csrc'),
                                        os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
ptrblck's avatar
ptrblck committed
540
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
541
                                              'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
542

543
544
545
if "--transducer" in sys.argv:
    sys.argv.remove("--transducer")

546
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
547
548
549
550
551
552
553
        raise RuntimeError("--transducer was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(name='transducer_joint_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
                                   'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
Masaki Kozuki's avatar
Masaki Kozuki committed
554
555
                                              'nvcc': ['-O3'] + version_dependent_macros},
                          include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
556
557
558
559
560
561
562
563
        ext_modules.append(
            CUDAExtension(name='transducer_loss_cuda',
                          sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
                                   'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
                          include_dirs=[os.path.join(this_dir, 'csrc')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-O3'] + version_dependent_macros}))

564
565
566
if "--fast_bottleneck" in sys.argv:
    sys.argv.remove("--fast_bottleneck")

567
    if CUDA_HOME is None and not IS_ROCM_PYTORCH:
568
569
570
571
572
573
        raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(name='fast_bottleneck',
                          sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
Masaki Kozuki's avatar
Masaki Kozuki committed
574
                          include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
575
576
                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))

577
578
if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")
579

Christian Sarofeen's avatar
Christian Sarofeen committed
580
setup(
    name='apex',
    version='0.1',
    # Exclude build artifacts, native-source trees, docs, and tests from the
    # installed package.  ('tests' appeared twice in the original exclude
    # tuple; the duplicate is removed here -- same effective exclusion set.)
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    description='PyTorch Extensions written by NVIDIA',
    # ext_modules / cmdclass are populated above depending on which
    # --<feature> flags were passed on the command line.
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    extras_require=extras,
)