"""Build script for the ``fastmoe`` package.

Packages the ``fmoe`` Python modules and builds the ``fmoe_cuda`` extension
from the sources under ``cuda/``.

Set the environment variable ``USE_NCCL=1`` to compile with
``-DFMOE_USE_NCCL`` and link against ``nccl``.
"""
import os

import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


# Extra compiler flags; the same list is handed to both the C++ compiler
# and nvcc in the extension definition below.
cxx_flags = []
# Extra libraries the extension is linked against.
ext_libs = []

authors = [
    'Jiaao He',
    'Jiezhong Qiu',
    'Aohan Zeng',
    'Tiago Antunes',
    'Jinjun Peng',
    'Qin Li',
]

# NCCL support is opt-in: only the exact value '1' enables it.
if os.environ.get('USE_NCCL', '0') == '1':
    cxx_flags.append('-DFMOE_USE_NCCL')
    ext_libs.append('nccl')


if __name__ == '__main__':
    setuptools.setup(
        name='fastmoe',
        version='0.2.0',
        description='An efficient Mixture-of-Experts system for PyTorch',
        author=', '.join(authors),
        author_email='hja20@mails.tsinghua.edu.cn',
        license='Apache-2',
        url='https://github.com/laekov/fastmoe',
        packages=['fmoe', 'fmoe.megatron', 'fmoe.gates'],
        ext_modules=[
            CUDAExtension(
                name='fmoe_cuda',
                sources=[
                    'cuda/stream_manager.cpp',
                    'cuda/local_exchange.cu',
                    'cuda/balancing.cu',
                    'cuda/global_exchange.cpp',
                    'cuda/parallel_linear.cu',
                    'cuda/fmoe_cuda.cpp',
                ],
                extra_compile_args={
                    'cxx': cxx_flags,
                    'nvcc': cxx_flags,
                },
                libraries=ext_libs,
            ),
        ],
        cmdclass={
            'build_ext': BuildExtension,
        },
    )