# setup.py: builds the fused_dense_lib CUDA extension (fused_dense.cpp + fused_dense_cuda.cu).
import os
import subprocess

import torch
from packaging.version import parse, Version
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    """Query the nvcc compiler under ``cuda_dir`` for its CUDA release version.

    Args:
        cuda_dir: Root directory of the CUDA toolkit; the compiler is invoked
            as ``<cuda_dir>/bin/nvcc``.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple: the full text printed by
        ``nvcc -V`` and the release number parsed with
        ``packaging.version.parse``.

    Raises:
        RuntimeError: if ``cuda_dir`` is ``None``.
        subprocess.CalledProcessError: if ``nvcc -V`` exits non-zero.
        ValueError: if no release number follows a ``release`` token in the
            ``nvcc -V`` output.
    """
    if cuda_dir is None:
        # torch's CUDA_HOME can be None when no toolkit was located; fail with a
        # readable message instead of a TypeError from building the nvcc path.
        raise RuntimeError(
            "CUDA toolkit not found (CUDA_HOME is None); "
            "set CUDA_HOME to build fused_dense_lib"
        )
    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], text=True)
    tokens = raw_output.split()
    try:
        release_token = tokens[tokens.index("release") + 1]
    except (ValueError, IndexError):
        raise ValueError(
            f"could not find a CUDA release number in `nvcc -V` output:\n{raw_output}"
        ) from None
    # The token after "release" may carry a trailing comma (e.g. "11.8,"); drop it.
    bare_metal_version = parse(release_token.split(",")[0])
    return raw_output, bare_metal_version


def append_nvcc_threads(nvcc_extra_args, num_threads=4):
    """Return ``nvcc_extra_args`` with ``--threads`` appended when nvcc is new enough.

    The flag is only added when the toolkit at ``CUDA_HOME`` reports a release
    of 11.2 or newer (the version gate below); for older toolkits the
    arguments are returned unchanged.

    Args:
        nvcc_extra_args: List of extra nvcc flags. It is never modified in place.
        num_threads: Value passed to ``--threads``. Defaults to 4, the value
            previously hard-coded here.

    Returns:
        A new list ending in ``["--threads", str(num_threads)]`` when the flag
        is supported; otherwise the ``nvcc_extra_args`` object itself.
    """
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version >= Version("11.2"):
        # List concatenation builds a new list, so the caller's list is untouched.
        return nvcc_extra_args + ["--threads", str(num_threads)]
    return nvcc_extra_args


# Per-toolchain compiler flags: "cxx" for the host C++ compiler, "nvcc" for
# the CUDA compiler (append_nvcc_threads may add a --threads option).
_compile_args = {
    "cxx": ["-O3"],
    "nvcc": append_nvcc_threads(["-O3"]),
}

# Single extension module built from the C++ binding and the CUDA source.
_fused_dense_ext = CUDAExtension(
    name="fused_dense_lib",
    sources=["fused_dense.cpp", "fused_dense_cuda.cu"],
    extra_compile_args=_compile_args,
)

setup(
    name="fused_dense_lib",
    ext_modules=[_fused_dense_ext],
    cmdclass={"build_ext": BuildExtension},
)