Commit fc2500e5 authored by wooway777's avatar wooway777
Browse files

issue/1033 - further fix nv lua for backward compatibility

parent e7a1b121
...@@ -235,14 +235,14 @@ option_end() ...@@ -235,14 +235,14 @@ option_end()
-- Flash-Attn -- Flash-Attn
option("flash-attn") option("flash-attn")
set_default(nil) set_default("")
set_showmenu(true) set_showmenu(true)
set_description("Path to flash-attention repo. If not set, flash-attention will not used.") set_description("Path to flash-attention repo. If not set, flash-attention will not used.")
option_end() option_end()
if has_config("aten") then if has_config("aten") then
add_defines("ENABLE_ATEN") add_defines("ENABLE_ATEN")
if get_config("flash-attn") ~= nil then if get_config("flash-attn") ~= false then
add_defines("ENABLE_FLASH_ATTN") add_defines("ENABLE_FLASH_ATTN")
end end
end end
...@@ -462,7 +462,7 @@ target("infinicore_cpp_api") ...@@ -462,7 +462,7 @@ target("infinicore_cpp_api")
add_linkdirs(INFINI_ROOT.."/lib") add_linkdirs(INFINI_ROOT.."/lib")
add_links("infiniop", "infinirt", "infiniccl") add_links("infiniop", "infinirt", "infiniccl")
if get_config("flash-attn") == true then if get_config("flash-attn") ~= "" then
add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"}) add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
if has_config("nv-gpu") then if has_config("nv-gpu") then
add_deps("flash-attn-nvidia") add_deps("flash-attn-nvidia")
......
...@@ -145,7 +145,7 @@ target("flash-attn-nvidia") ...@@ -145,7 +145,7 @@ target("flash-attn-nvidia")
add_links("cudart") add_links("cudart")
add_cugencodes("native") add_cugencodes("native")
if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= false and FLASH_ATTN_ROOT ~= "" then if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
before_build(function (target) before_build(function (target)
local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment