Commit 9015e384 authored by wooway777
Browse files

issue/1033 - fix nvidia lua

parent e6b6fba5
...@@ -149,9 +149,10 @@ target("flash-attn-nvidia") ...@@ -149,9 +149,10 @@ target("flash-attn-nvidia")
if FLASH_ATTN_ROOT ~= nil then if FLASH_ATTN_ROOT ~= nil then
local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
local PYTHON_LIB_DIR= os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
local LIB_PYTHON = os.iorunv("python", {"-c", "import sysconfig, os; print(sysconfig.get_config_var('LDLIBRARY'))"}):trim() local LIB_PYTHON = os.iorunv("python", {"-c", "import sysconfig, os; print(sysconfig.get_config_var('LDLIBRARY'))"}):trim()
-- Include dirs
-- Include dirs (needed for both device and host)
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false}) target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false}) target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
target:add("includedirs", TORCH_DIR .. "/include", {public = false}) target:add("includedirs", TORCH_DIR .. "/include", {public = false})
...@@ -159,19 +160,25 @@ target("flash-attn-nvidia") ...@@ -159,19 +160,25 @@ target("flash-attn-nvidia")
target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false}) target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false}) target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})
-- Link libraries -- For device linking, only add CUDA-related link directories
target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR) target:add("linkdirs", TORCH_DIR .. "/lib", {public = false, force = true})
target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", LIB_PYTHON)
-- For host linking, we need to add these via link options
-- Use add_ldflags to pass library paths to the host linker only
target:add("ldflags", "-L" .. TORCH_DIR .. "/lib", {force = true})
target:add("ldflags", "-L" .. PYTHON_LIB_DIR, {force = true})
target:add("ldflags", "-l" .. LIB_PYTHON:gsub("%.so$", ""):gsub("^lib", ""), {force = true})
target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python")
end end
end) end)
if FLASH_ATTN_ROOT ~= nil then if FLASH_ATTN_ROOT ~= nil then
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp") add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu") add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
-- Link options -- Link options
add_ldflags("-Wl,--no-undefined", {force = true}) add_ldflags("-Wl,--no-undefined", {force = true})
-- Compile options -- Compile options
add_cxflags("-fPIC", {force = true}) add_cxflags("-fPIC", {force = true})
add_cuflags("-Xcompiler=-fPIC") add_cuflags("-Xcompiler=-fPIC")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment