setup.py 4.7 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
import os
import subprocess
import setuptools
4
import importlib
5
6

from pathlib import Path
Chenggang Zhao's avatar
Chenggang Zhao committed
7
8
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

Chenggang Zhao's avatar
Chenggang Zhao committed
9
10

# Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X`
11
12
13
14
def get_nvshmem_host_lib_name(base_dir):
    """Return the file name of the NVSHMEM host shared library found under ``base_dir/lib``.

    NVSHMEM wheels only ship the versioned soname of the host library
    (``libnvshmem_host.so.X``). The caller therefore links against the exact file
    name via ``-l:<name>`` rather than relying on ``-lnvshmem_host``.

    Args:
        base_dir: Root directory of an NVSHMEM installation, e.g. the
            ``nvidia.nvshmem`` package directory.

    Returns:
        The bare file name (no directory part) of a ``libnvshmem_host.so.*`` file
        located anywhere below ``base_dir/lib``. If several match, the
        lexicographically smallest name is returned, so the choice is deterministic.
        For ``libnvshmem_host.so.3`` vs ``libnvshmem_host.so.3.2.5`` this is the
        shorter soname.

    Raises:
        ModuleNotFoundError: If no matching library exists (including when
            ``base_dir/lib`` itself is missing).
    """
    path = Path(base_dir).joinpath('lib')
    # `rglob` yields entries in filesystem order, which differs across machines and
    # filesystems. Sort so the same library is picked on every build.
    candidates = sorted(file.name for file in path.rglob('libnvshmem_host.so.*'))
    if not candidates:
        raise ModuleNotFoundError('libnvshmem_host.so not found')
    return candidates[0]
Chenggang Zhao's avatar
Chenggang Zhao committed
16

Chenggang Zhao's avatar
Chenggang Zhao committed
17

Chenggang Zhao's avatar
Chenggang Zhao committed
18
if __name__ == '__main__':
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is
    # loaded. Import it explicitly. Otherwise `importlib.util.find_spec` below can raise
    # AttributeError, which the handler would misreport as "NVSHMEM module is not installed".
    import importlib.util

    # Resolve NVSHMEM. An explicit `NVSHMEM_DIR` takes precedence; otherwise fall back to
    # the `nvidia.nvshmem` Python package. If neither is usable, the internode and
    # low-latency kernels are compiled out (see `-DDISABLE_NVSHMEM` below).
    disable_nvshmem = False
    nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
    nvshmem_host_lib = 'libnvshmem_host.so'
    if nvshmem_dir is None:
        try:
            # `find_spec` raises ModuleNotFoundError when the parent `nvidia` package is
            # missing, and returns None (-> AttributeError on the attribute access) when
            # only `nvidia.nvshmem` is missing.
            nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
            nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir)
            # Only checks that the package imports; the module object itself is not used.
            import nvidia.nvshmem as nvshmem
        except (ModuleNotFoundError, AttributeError, IndexError):
            print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n')
            disable_nvshmem = True
            # A partially resolved path (e.g. package found but host library missing) is
            # not used for the build; clear it so the summary does not report it.
            nvshmem_dir = None

    if not disable_nvshmem and not os.path.exists(nvshmem_dir):
        raise FileNotFoundError(f'The specified NVSHMEM directory does not exist: {nvshmem_dir}')

    cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
                 '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
    nvcc_flags = ['-O3', '-Xcompiler', '-O3']
    sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu']
    include_dirs = ['csrc/']
    library_dirs = []
    nvcc_dlink = []
    extra_link_args = []

    # NVSHMEM flags
    if disable_nvshmem:
        cxx_flags.append('-DDISABLE_NVSHMEM')
        nvcc_flags.append('-DDISABLE_NVSHMEM')
    else:
        sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])
        include_dirs.extend([f'{nvshmem_dir}/include'])
        library_dirs.extend([f'{nvshmem_dir}/lib'])
        # Flags for nvcc's device-link step (`-dlink`) against the NVSHMEM device library.
        nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])
        # `-l:<file>` links the exact host-library file name (wheels ship only the
        # versioned soname); the rpath points the built extension at `{nvshmem_dir}/lib`.
        extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])

    if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
        # Prefer A100
        os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0')

        # Disable some SM90 features: FP8, launch methods, and TMA
        cxx_flags.append('-DDISABLE_SM90_FEATURES')
        nvcc_flags.append('-DDISABLE_SM90_FEATURES')

        # Internode and low-latency kernels must be disabled in this mode
        if not disable_nvshmem:
            raise ValueError('`DISABLE_SM90_FEATURES` is set, but NVSHMEM was found; internode and '
                             'low-latency kernels must be disabled in this mode. Unset `NVSHMEM_DIR` '
                             'and/or uninstall the NVSHMEM Python package.')
    else:
        # Prefer H800 series
        os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')

        # CUDA 12 flags
        nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])

    # Disable LD/ST tricks, as some CUDA versions do not support `.L1::no_allocate`
    if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0':
        if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) != 1:
            raise ValueError('Aggressive PTX instructions require `TORCH_CUDA_ARCH_LIST=9.0` '
                             f'(got {os.environ["TORCH_CUDA_ARCH_LIST"]!r}); '
                             'unset `DISABLE_AGGRESSIVE_PTX_INSTRS` or set it to 1.')
        os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'

    # Disable aggressive PTX instructions (the default unless explicitly set to 0)
    if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):
        cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
        nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')

    # Put them together
    extra_compile_args = {
        'cxx': cxx_flags,
        'nvcc': nvcc_flags,
    }
    if nvcc_dlink:
        extra_compile_args['nvcc_dlink'] = nvcc_dlink

    # Summary
    print('Build summary:')
    print(f' > Sources: {sources}')
    print(f' > Includes: {include_dirs}')
    print(f' > Libraries: {library_dirs}')
    print(f' > Compilation flags: {extra_compile_args}')
    print(f' > Link flags: {extra_link_args}')
    print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}')
    print(f' > NVSHMEM path: {nvshmem_dir}')
    print()

    # Append the short git revision to the version when building from a checkout. A missing
    # `git` binary (OSError) or a non-repository source tree (CalledProcessError) simply
    # yields no suffix; git's own error output is suppressed.
    try:
        cmd = ['git', 'rev-parse', '--short', 'HEAD']
        revision = '+' + subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode('ascii').rstrip()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        revision = ''

    setuptools.setup(
        name='deep_ep',
        version='1.1.0' + revision,
        packages=setuptools.find_packages(
            include=['deep_ep']
        ),
        ext_modules=[
            CUDAExtension(
                name='deep_ep_cpp',
                include_dirs=include_dirs,
                library_dirs=library_dirs,
                sources=sources,
                extra_compile_args=extra_compile_args,
                extra_link_args=extra_link_args
            )
        ],
        cmdclass={
            'build_ext': BuildExtension
        }
    )