Unverified Commit dce99862 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1053 from InfiniTensor/issue/1033xmake

Issue/1033 patch aten and fa adaptations
parents 8d99a8f5 d6e44e84
......@@ -226,6 +226,28 @@ if has_config("ninetoothed") then
add_defines("ENABLE_NINETOOTHED")
end
-- ATen
-- Build option: opt-in ATen/libtorch integration. When enabled, targets
-- link against the PyTorch installation found via the active Python.
option("aten")
    set_default(false)
    set_showmenu(true)
    -- Fixed typo in the user-visible help text: "Wether" -> "Whether".
    set_description("Whether to link aten and torch libraries")
option_end()
-- Flash-Attn
-- Build option: path to a flash-attention source checkout. Leaving it
-- unset (nil) disables flash-attention support entirely.
option("flash-attn")
    set_default(nil)
    set_showmenu(true)
    -- Fixed grammar in the user-visible help text: "will not used" -> "will not be used".
    set_description("Path to flash-attention repo. If not set, flash-attention will not be used.")
option_end()
-- Propagate the option choices to the preprocessor. ENABLE_FLASH_ATTN is
-- only meaningful on top of ATen, hence the nesting.
if has_config("aten") then
    add_defines("ENABLE_ATEN")
    local flash_attn_repo = get_config("flash-attn")
    if flash_attn_repo ~= nil then
        add_defines("ENABLE_FLASH_ATTN")
    end
end
-- cuda graph
option("graph")
set_default(false)
......@@ -314,6 +336,7 @@ target("infinirt")
if not is_plat("windows") then
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
add_ldflags("-fPIC", {force = true})
end
set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini"))
add_files("src/infinirt/*.cc")
......@@ -439,8 +462,44 @@ target("infinicore_cpp_api")
add_linkdirs(INFINI_ROOT.."/lib")
add_links("infiniop", "infinirt", "infiniccl")
-- When a flash-attention checkout is configured, ship the built shared
-- object with the package and (on NVIDIA builds) depend on its target.
local flash_attn_repo = get_config("flash-attn")
if flash_attn_repo ~= nil then
    add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
    if has_config("nv-gpu") then
        add_deps("flash-attn-nvidia")
    end
end
before_build(function (target)
    -- Wire up libtorch headers/libraries only when ATen support is on;
    -- otherwise this hook is a no-op.
    if not has_config("aten") then
        return
    end
    -- Locate the PyTorch installation of whichever Python is on PATH.
    local torch_dir = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
    target:add(
        "includedirs",
        path.join(torch_dir, "include"),
        path.join(torch_dir, "include/torch/csrc/api/include"),
        { public = true })
    target:add(
        "linkdirs",
        path.join(torch_dir, "lib"),
        { public = true }
    )
    -- NOTE(review): torch_cuda/c10_cuda assume a CUDA-enabled torch wheel
    -- is installed — confirm for CPU-only environments.
    target:add(
        "links",
        "torch",
        "c10",
        "torch_cuda",
        "c10_cuda",
        { public = true }
    )
end)
-- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
add_files("src/infinicore/*.cc")
add_files("src/infinicore/adaptor/*.cc")
add_files("src/infinicore/context/*.cc")
add_files("src/infinicore/context/*/*.cc")
add_files("src/infinicore/tensor/*.cc")
......
......@@ -9,6 +9,10 @@ if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
end
-- Path to the flash-attention checkout; nil when the option is unset.
local FLASH_ATTN_ROOT = get_config("flash-attn")
-- Install prefix: $INFINI_ROOT if set, else ~/.infini (HOMEPATH-based on Windows hosts).
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
target("infiniop-nvidia")
set_kind("static")
add_deps("infini-utils")
......@@ -132,3 +136,53 @@ target("infiniccl-nvidia")
set_languages("cxx17")
target_end()
-- Shared library wrapping flash-attention's CUDA kernels and flash_api.cpp,
-- linked against libtorch. Degrades to an empty no-op target when no
-- flash-attention checkout is configured.
target("flash-attn-nvidia")
set_kind("shared")
-- Not built by default; reached via add_deps() from consumers.
set_default(false)
-- Device-link CUDA objects into the shared library.
set_policy("build.cuda.devlink", true)
set_toolchains("cuda")
add_links("cudart")
-- Generate code for the GPU architecture of the build machine.
add_cugencodes("native")
-- get_config may yield nil, false, or "" when the option is unset — guard all three.
if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= false and FLASH_ATTN_ROOT ~= "" then
before_build(function (target)
-- Resolve torch and CPython locations from the active Python interpreter
-- at build time, so the build tracks the current environment.
local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
-- Full path to libpythonX.Y.so (first glob match in LIBDIR).
local LIB_PYTHON = os.iorunv("python", {"-c", "import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])"}):trim()
-- Include dirs (needed for both device and host)
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
target:add("includedirs", TORCH_DIR .. "/include", {public = false})
target:add("includedirs", PYTHON_INCLUDE, {public = false})
target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})
-- Link libraries
-- NOTE(review): LIB_PYTHON is an absolute .so path passed to "links" —
-- confirm xmake forwards full paths to the linker as intended.
target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", LIB_PYTHON)
end)
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
-- Link options
-- Fail the link if any symbol is left unresolved.
add_ldflags("-Wl,--no-undefined", {force = true})
-- Compile options
add_cxflags("-fPIC", {force = true})
add_cuflags("-Xcompiler=-fPIC")
add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
-- Relocatable device code disabled; devlink policy above handles linking.
set_values("cuda.rdc", false)
else
-- If flash-attn is not available, just create an empty target
before_build(function (target)
print("Flash Attention not available, skipping flash-attn-nvidia build")
end)
end
-- Nothing to install for this target; installation is handled by consumers.
on_install(function (target) end)
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment