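-- Locate cuDNN via common environment variables and expose its headers.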
local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH")
if CUDNN_ROOT ~= nil then
    add_includedirs(CUDNN_ROOT .. "/include")
end

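-- CUTLASS is header-only; putting its root on the include path is sufficient.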
local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")

if CUTLASS_ROOT ~= nil then
    add_includedirs(CUTLASS_ROOT)
end

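-- Optional path to a flash-attention source checkout, via the `flash-attn` config option.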
local FLASH_ATTN_ROOT = get_config("flash-attn")

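-- Install prefix for InfiniCore artifacts: $INFINI_ROOT, defaulting to ~/.infini.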
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")

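-- Operator library for NVIDIA GPUs.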
target("infiniop-nvidia")
    set_kind("static")
    add_deps("infini-utils")
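    -- No-op installer: skip xmake's default install step for this target.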
    on_install(function (target) end)

    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart", "cublas")
    if has_config("cudnn") then
        add_links("cudnn")
    end

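    -- Locate nvcc at load time so the CUDA driver stubs next to the toolkit can be linked.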
    on_load(function (target)
        import("lib.detect.find_tool")
        local nvcc = find_tool("nvcc")
        if nvcc ~= nil then
            local nvcc_path
            if is_plat("windows") then
                -- `where` may return multiple hits; take the first line.
                nvcc_path = os.iorun("where nvcc"):match("(.-)\r?\n")
            else
                nvcc_path = nvcc.program
            end

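            -- <cuda_root>/lib64/stubs carries the libcuda stub, so linking succeeds without an installed driver.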
            target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
            target:add("links", "cuda")
        end
    end)

    if is_plat("windows") then
        add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
        add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX")
        add_cxxflags("/FS")
        if CUDNN_ROOT ~= nil then
            add_linkdirs(CUDNN_ROOT .. "\\lib\\x64")
        end
    else
        add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--extended-lambda")
        add_culdflags("-Xcompiler=-fPIC")
        add_cxflags("-fPIC")
        add_cxxflags("-fPIC")
        add_cflags("-fPIC")
        add_cuflags("--expt-relaxed-constexpr")
        if CUDNN_ROOT ~= nil then
            add_linkdirs(CUDNN_ROOT .. "/lib")
        end
    end

    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")

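    -- Emit code for each architecture listed in `cuda_arch` (e.g. "sm_80,sm_90");
    -- otherwise build natively for the GPU detected on this machine.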
    local arch_opt = get_config("cuda_arch")
    if arch_opt and type(arch_opt) == "string" then
        for _, arch in ipairs(arch_opt:split(",")) do
            arch = arch:trim()
            local compute = arch:gsub("sm_", "compute_")
            add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
        end
    else
        add_cugencodes("native")
    end

    set_languages("cxx17")
    add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu", "../src/infiniop/ops/*/*/nvidia/*.cu")

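    -- Extra sources expected under build/ninetoothed when the `ninetoothed` option is on.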
    if has_config("ninetoothed") then
        add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp")
    end
target_end()

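-- Runtime library for NVIDIA GPUs.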
target("infinirt-nvidia")
    set_kind("static")
    add_deps("infini-utils")
    on_install(function (target) end)

    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")

    if is_plat("windows") then
        add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
        add_cxxflags("/FS")
    else
        add_cuflags("-Xcompiler=-fPIC")
        add_culdflags("-Xcompiler=-fPIC")
        add_cxflags("-fPIC")
        add_cxxflags("-fPIC")
    end

    set_languages("cxx17")
    add_files("../src/infinirt/cuda/*.cu")
target_end()

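-- Collective-communication library; only built when the `ccl` option is enabled.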
target("infiniccl-nvidia")
    set_kind("static")
    add_deps("infinirt")
    on_install(function (target) end)
    if has_config("ccl") then
        set_policy("build.cuda.devlink", true)
        set_toolchains("cuda")
        add_links("cudart")

        if not is_plat("windows") then
            add_cuflags("-Xcompiler=-fPIC")
            add_culdflags("-Xcompiler=-fPIC")
            add_cxflags("-fPIC")
            add_cxxflags("-fPIC")

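            -- Prefer an explicitly provided NCCL installation over whatever the system linker finds.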
            local nccl_root = os.getenv("NCCL_ROOT")
            if nccl_root then
                add_includedirs(nccl_root .. "/include")
                add_links(nccl_root .. "/lib/libnccl.so")
            else
                add_links("nccl") -- Fall back to default nccl linking
            end

            add_files("../src/infiniccl/cuda/*.cu")
        else
            print("[Warning] NCCL is not supported on Windows")
        end
    end
    set_languages("cxx17")

target_end()

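-- Optional shared library wrapping the upstream flash-attention kernels;
-- needs Torch, CUTLASS, and a flash-attention checkout.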
target("flash-attn-nvidia")
    set_kind("shared")
    set_default(false)
    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")
    add_cugencodes("native")

    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
        before_build(function (target)
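            -- Probe the active Python for Torch and CPython include/library paths.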
            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
            local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
            local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
            local LIB_PYTHON = os.iorunv("python", {"-c", "import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])"}):trim()

            -- Include dirs (needed for both device and host)
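            -- (CUTLASS_ROOT must be set here; flash-attention depends on its headers.)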
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include", {public = false})
            target:add("includedirs", PYTHON_INCLUDE, {public = false})
            target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})

            -- Link libraries
            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
            target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", LIB_PYTHON)
        end)

        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")

        -- Link options
        add_ldflags("-Wl,--no-undefined", {force = true})

        -- Compile options
        add_cxflags("-fPIC", {force = true})
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
        set_values("cuda.rdc", false)
    else
        -- If flash-attn is not available, just create an empty target
        before_build(function (target)
            print("Flash Attention not available, skipping flash-attn-nvidia build")
        end)
    end

    on_install(function (target) end)

target_end()