nvidia.lua 6.36 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
3
4
5
-- Return the value of the first environment variable in `names` that is set,
-- or nil when none of them is.
local function first_env(names)
    for _, name in ipairs(names) do
        local value = os.getenv(name)
        if value then
            return value
        end
    end
    return nil
end

-- cuDNN install prefix; when present its include/ directory is exposed.
local CUDNN_ROOT = first_env({"CUDNN_ROOT", "CUDNN_HOME", "CUDNN_PATH"})
if CUDNN_ROOT then
    add_includedirs(CUDNN_ROOT .. "/include")
end

-- CUTLASS checkout; note that the root itself (not root/include) is added
-- as an include directory here.
local CUTLASS_ROOT = first_env({"CUTLASS_ROOT", "CUTLASS_HOME", "CUTLASS_PATH"})
if CUTLASS_ROOT then
    add_includedirs(CUTLASS_ROOT)
end

-- Value of the "flash-attn" configuration option (used as a source-tree
-- path by the flash-attn-nvidia target).
local FLASH_ATTN_ROOT = get_config("flash-attn")

-- INFINI_ROOT from the environment, otherwise "<home>/.infini" where <home>
-- is HOMEPATH on Windows hosts and HOME elsewhere.
local INFINI_ROOT = os.getenv("INFINI_ROOT")
if not INFINI_ROOT then
    local home_var = "HOME"
    if is_host("windows") then
        home_var = "HOMEPATH"
    end
    INFINI_ROOT = os.getenv(home_var) .. "/.infini"
end

16
-- Static library holding the NVIDIA (CUDA) operator implementations,
-- compiled from the .cu sources under ../src/infiniop.
target("infiniop-nvidia")
    set_kind("static")
    add_deps("infini-utils")
    -- Empty install hook: this target installs nothing on its own.
    on_install(function (target) end)

    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart", "cublas")
    if has_config("cudnn") then
        add_links("cudnn")
    end

    -- Link against the CUDA driver library, looking for it in the
    -- lib64/stubs directory derived from the location of nvcc.
    on_load(function (target)
        import("lib.detect.find_tool")
        local nvcc = find_tool("nvcc")
        if nvcc ~= nil then
            -- Keep this local: a bare assignment would create a global.
            local nvcc_path
            if is_plat("windows") then
                -- Take the first line printed by `where`; fall back to the
                -- detected program if the output has no complete line.
                nvcc_path = os.iorun("where nvcc"):match("(.-)\r?\n") or nvcc.program
            else
                nvcc_path = nvcc.program
            end

            -- <nvcc dir>/.. is taken as the toolkit root.
            target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
            target:add("links", "cuda")
        end
    end)

    if is_plat("windows") then
        add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
        add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX")
        add_cxxflags("/FS")
        if CUDNN_ROOT ~= nil then
            add_linkdirs(CUDNN_ROOT .. "\\lib\\x64")
        end
    else
        add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--extended-lambda")
        add_culdflags("-Xcompiler=-fPIC")
        add_cxflags("-fPIC")
        add_cxxflags("-fPIC")
        add_cflags("-fPIC")
        add_cuflags("--expt-relaxed-constexpr")
        if CUDNN_ROOT ~= nil then
            add_linkdirs(CUDNN_ROOT .. "/lib")
        end
    end

    -- Keep these two diagnostics as warnings even when warnings are errors.
    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")

    -- "cuda_arch" is a comma-separated list such as "sm_80,sm_90"; each entry
    -- becomes one -gencode flag. Without it, build for the native GPU.
    local arch_opt = get_config("cuda_arch")
    if arch_opt and type(arch_opt) == "string" then
        for _, arch in ipairs(arch_opt:split(",")) do
            arch = arch:trim()
            -- gsub also returns a match count; only the string is kept.
            local compute = arch:gsub("sm_", "compute_")
            add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
        end
    else
        add_cugencodes("native")
    end

    set_languages("cxx17")
    add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu", "../src/infiniop/ops/*/*/nvidia/*.cu")

    if has_config("ninetoothed") then
        add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp")
    end
target_end()
84

85
-- Static library compiled from the CUDA sources under ../src/infinirt/cuda.
target("infinirt-nvidia")
    set_kind("static")
    set_languages("cxx17")
    add_deps("infini-utils")

    -- Empty install hook: this target installs nothing on its own.
    on_install(function (_) end)

    -- CUDA toolchain, device-link policy and runtime library.
    set_toolchains("cuda")
    set_policy("build.cuda.devlink", true)
    add_links("cudart")

    if not is_plat("windows") then
        -- Position-independent code for device, device-link and host objects.
        add_cuflags("-Xcompiler=-fPIC")
        add_culdflags("-Xcompiler=-fPIC")
        add_cxflags("-fPIC")
        add_cxxflags("-fPIC")
    else
        add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
        add_cxxflags("/FS")
    end

    add_files("../src/infinirt/cuda/*.cu")
target_end()
107

108
-- Static library for the NCCL-based collective layer. It only receives
-- sources and CUDA settings when the "ccl" option is enabled.
target("infiniccl-nvidia")
    set_kind("static")
    set_languages("cxx17")
    add_deps("infinirt")

    -- Empty install hook: this target installs nothing on its own.
    on_install(function (_) end)

    if has_config("ccl") then
        set_policy("build.cuda.devlink", true)
        set_toolchains("cuda")
        add_links("cudart")

        if is_plat("windows") then
            print("[Warning] NCCL is not supported on Windows")
        else
            add_cuflags("-Xcompiler=-fPIC")
            add_culdflags("-Xcompiler=-fPIC")
            add_cxflags("-fPIC")
            add_cxxflags("-fPIC")

            -- Use the NCCL install named by NCCL_ROOT when it is set;
            -- otherwise rely on the default library search for "nccl".
            local nccl_root = os.getenv("NCCL_ROOT")
            if not nccl_root then
                add_links("nccl")
            else
                add_includedirs(nccl_root .. "/include")
                add_links(nccl_root .. "/lib/libnccl.so")
            end

            add_files("../src/infiniccl/cuda/*.cu")
        end
    end
target_end()
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

-- Optional shared library built from the flash-attention sources found at
-- the "flash-attn" config path. It is not part of the default build and
-- only gets sources and flags when that option is set.
target("flash-attn-nvidia")
    set_kind("shared")
    set_default(false)
    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")
    add_cugencodes("native")

    -- Resolve the PyTorch / Python / CUTLASS locations right before building
    -- and attach them to the target as private include and link settings.
    before_build(function (target)
        if FLASH_ATTN_ROOT == nil then
            return
        end

        -- Fail with a clear message instead of a nil-concatenation error.
        if CUTLASS_ROOT == nil then
            raise("flash-attn-nvidia requires CUTLASS: set CUTLASS_ROOT, CUTLASS_HOME or CUTLASS_PATH")
        end

        -- Run a one-line Python snippet and return its trimmed stdout.
        local function python_eval(code)
            return os.iorunv("python", {"-c", code}):trim()
        end

        local TORCH_DIR = python_eval("import torch, os; print(os.path.dirname(torch.__file__))")
        local PYTHON_INCLUDE = python_eval("import sysconfig; print(sysconfig.get_paths()['include'])")
        local PYTHON_LIB_DIR = python_eval("import sysconfig; print(sysconfig.get_config_var('LIBDIR'))")

        -- Include dirs
        target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
        target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
        target:add("includedirs", TORCH_DIR .. "/include", {public = false})
        target:add("includedirs", PYTHON_INCLUDE, {public = false})
        target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
        target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})

        -- Link libraries
        target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
        target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3")
    end)

    if FLASH_ATTN_ROOT ~= nil then
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
        -- Link options
        add_ldflags("-Wl,--no-undefined", {force = true})
        -- Compile options
        add_cxflags("-fPIC", {force = true})
        add_cuflags("-Xcompiler=-fPIC")
        -- One flag per argument rather than a single space-joined string.
        add_cuflags("--forward-unknown-to-host-compiler", "--expt-relaxed-constexpr", "--use_fast_math", {force = true})
        set_values("cuda.rdc", false)
    end

    -- Empty install hook: this target installs nothing on its own.
    on_install(function (target) end)

target_end()