Commit 5de45ee6 authored by yaoht

Integrate flash-attention, adapt for Hygon DCU, and optimize the addrmsnorm and rope operators

parent 93191613
Pipeline #3510 failed with stages in 0 seconds
......@@ -15,13 +15,16 @@ from framework import (
    TestCase,
)

+# gfx936 (Hygon DCU) paged attention only supports page_block_size=64
+_BLOCK_SIZE = 64 if "--hygon" in sys.argv else 256
# Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch])
_TEST_CASES_DATA = [
-    (1, 1, 128, 256, [(250,), (7,)]),
-    (4, 4, 128, 256, [(250,), (7,)]),
-    (1, 1, 128, 256, [(260, 73), (1, 1)]),
-    (8, 2, 128, 256, [(250,), (7,)]),
-    (8, 2, 128, 256, [(260, 73), (1, 1)]),
+    (1, 1, 128, _BLOCK_SIZE, [(250,), (7,)]),
+    (4, 4, 128, _BLOCK_SIZE, [(250,), (7,)]),
+    (1, 1, 128, _BLOCK_SIZE, [(260, 73), (1, 1)]),
+    (8, 2, 128, _BLOCK_SIZE, [(250,), (7,)]),
+    (8, 2, 128, _BLOCK_SIZE, [(260, 73), (1, 1)]),
]
_MAX_SEQUENCE_LENGTH = 8192
......
......@@ -200,6 +200,8 @@ option_end()
if has_config("hygon-dcu") then
add_defines("ENABLE_HYGON_API")
-- Required by HIP headers included from torch ATen/hip.
add_defines("__HIP_PLATFORM_AMD__")
includes("xmake/hygon.lua")
end
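HIP's public headers branch on `__HIP_PLATFORM_AMD__` versus `__HIP_PLATFORM_NVIDIA__`, so any translation unit that pulls them in through torch's ATen/hip headers needs the AMD platform macro on a DCU build; defining it project-wide here avoids per-target flags.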
......@@ -240,9 +242,20 @@ option("flash-attn")
set_description("Path to flash-attention repo. If not set, flash-attention will not used.")
option_end()
option("flash-attn-prebuilt")
set_default("")
set_showmenu(true)
set_description("Path to prebuilt flash_attn .so file or directory containing it. Used for Hygon DCU.")
option_end()
if has_config("aten") then
add_defines("ENABLE_ATEN")
if get_config("flash-attn") ~= false then
local fa_src = get_config("flash-attn")
local fa_prebuilt = get_config("flash-attn-prebuilt")
if not fa_prebuilt or fa_prebuilt == "" then
fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT")
end
if (fa_src and fa_src ~= "") or (fa_prebuilt and fa_prebuilt ~= "") then
add_defines("ENABLE_FLASH_ATTN")
end
end
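A prebuilt binary can be supplied either as a file or as a directory, e.g. via `xmake f --flash-attn-prebuilt=/path/to/flash_attn_2_cuda.so` (path hypothetical) or through the FLASH_ATTN_PREBUILT environment variable. The same option-then-environment-then-python lookup recurs below in the on_load and after_install hooks; if that duplication ever becomes a maintenance burden, the three copies could share one helper. A sketch, assuming it is pasted into each script-scope hook and relying only on the sandbox APIs already used in this file (the function name is ours):

local function resolve_flash_attn_prebuilt()
    -- 1) explicit option, 2) environment variable
    local prebuilt = get_config("flash-attn-prebuilt")
    if not prebuilt or prebuilt == "" then
        prebuilt = os.getenv("FLASH_ATTN_PREBUILT")
    end
    if prebuilt and prebuilt ~= "" then
        -- accept either the .so itself or a directory containing it
        if os.isfile(prebuilt) then
            return prebuilt
        end
        local files = os.files(path.join(prebuilt, "flash_attn_2_cuda*.so"))
        return #files > 0 and files[1] or nil
    end
    -- 3) fall back to the module installed in the active python
    local ok, so_path = pcall(function()
        return os.iorunv("python", {"-c", "import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)"}):trim()
    end)
    if ok and so_path and so_path ~= "" and os.isfile(so_path) then
        return so_path
    end
    return nil
end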
......@@ -469,14 +482,94 @@ target("infinicore_cpp_api")
end
end
+before_build(function (target)
+    if has_config("hygon-dcu") then
+        local cuda_sdk = get_config("cuda") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH")
+        local dtk_root = os.getenv("DTK_ROOT") or "/opt/dtk"
+        local function normalize_cuda_root(root)
+            if not root or root == "" or not os.isdir(root) then
+                return nil
+            end
+            if os.isdir(path.join(root, "include")) then
+                return root
+            end
+            local nested = {
+                path.join(root, "cuda"),
+                path.join(root, "cuda-12")
+            }
+            for _, cand in ipairs(nested) do
+                if os.isdir(path.join(cand, "include")) then
+                    return cand
+                end
+            end
+            return root
+        end
+        -- Prefer xmake --cuda=... for deterministic SDK include/link paths.
+        local normalized_cuda_sdk = normalize_cuda_root(cuda_sdk)
+        if normalized_cuda_sdk then
+            -- description-scope add_includedirs()/add_linkdirs() are not visible
+            -- inside script-scope hooks, so attach the dirs via target:add()
+            target:add("includedirs", path.join(normalized_cuda_sdk, "include"))
+            target:add("linkdirs", path.join(normalized_cuda_sdk, "lib64"))
+        end
+        -- Keep DTK fallback paths for environments where only DTK_ROOT is set.
+        if dtk_root and dtk_root ~= "" and os.isdir(dtk_root) then
+            target:add("includedirs", path.join(dtk_root, "include"))
+            target:add("includedirs", path.join(dtk_root, "cuda", "include"))
+            target:add("linkdirs", path.join(dtk_root, "lib"))
+            target:add("linkdirs", path.join(dtk_root, "cuda", "lib64"))
+        end
+    end
+end)
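On DCU machines the CUDA-style SDK apparently lives inside DTK (commonly `/opt/dtk`, with a `cuda/` compatibility subtree), which is why the fallback probes both `<DTK_ROOT>/include` and `<DTK_ROOT>/cuda/include` plus the matching `lib` and `cuda/lib64` directories.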
on_load(function (target)
    if has_config("aten") then
+        -- Hygon DCU: link prebuilt flash_attn BEFORE torch for correct symbol resolution order
+        if has_config("hygon-dcu") then
+            local fa_prebuilt = get_config("flash-attn-prebuilt")
+            if not fa_prebuilt or fa_prebuilt == "" then
+                fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT")
+            end
+            local flash_so_dir = nil
+            local flash_so_name = nil
+            if fa_prebuilt and fa_prebuilt ~= "" then
+                if os.isfile(fa_prebuilt) then
+                    flash_so_dir = path.directory(fa_prebuilt)
+                    flash_so_name = path.filename(fa_prebuilt)
+                else
+                    flash_so_dir = fa_prebuilt
+                    local files = os.files(path.join(fa_prebuilt, "flash_attn_2_cuda*.so"))
+                    if #files > 0 then
+                        flash_so_name = path.filename(files[1])
+                    end
+                end
+            else
+                local ok, so_path = pcall(function()
+                    return os.iorunv("python", {"-c", "import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)"}):trim()
+                end)
+                if ok and so_path and so_path ~= "" and os.isfile(so_path) then
+                    flash_so_dir = path.directory(so_path)
+                    flash_so_name = path.filename(so_path)
+                end
+            end
+            if flash_so_dir and flash_so_name then
+                target:add("linkdirs", flash_so_dir)
target:add("ldflags", "-Wl,--no-as-needed", {force = true})
target:add("ldflags", "-l:" .. flash_so_name, {force = true})
target:add("ldflags", "-Wl,--as-needed", {force = true})
print("Flash Attention library: " .. path.join(flash_so_dir, flash_so_name))
end
end
        local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
        local TORCH_DIR = outdata
+        -- Use sysincludedirs (-isystem) so that torch's bundled pybind11 headers
+        -- do not shadow the xmake pybind11 package headers.
        target:add(
-            "includedirs",
-            path.join(TORCH_DIR, "include"),
+            "sysincludedirs",
+            path.join(TORCH_DIR, "include"),
            path.join(TORCH_DIR, "include/torch/csrc/api/include"),
            { public = true })
......@@ -485,14 +578,40 @@ target("infinicore_cpp_api")
path.join(TORCH_DIR, "lib"),
{ public = true }
)
-        target:add(
-            "links",
-            "torch",
-            "c10",
-            "torch_cuda",
-            "c10_cuda",
-            { public = true }
-        )
+        local torch_libdir = path.join(TORCH_DIR, "lib")
+        target:add("rpathdirs", torch_libdir)
+        target:add("ldflags", "-Wl,--no-as-needed", {force = true})
+        local torch_links = {"torch", "c10"}
+        local function has_torch_lib(name)
+            return #os.files(path.join(torch_libdir, "lib" .. name .. ".so*")) > 0
+        end
+        if has_torch_lib("torch_cuda") then
+            table.insert(torch_links, "torch_cuda")
+        elseif has_torch_lib("torch_hip") then
+            table.insert(torch_links, "torch_hip")
+        end
+        if has_torch_lib("c10_cuda") then
+            table.insert(torch_links, "c10_cuda")
+        elseif has_torch_lib("c10_hip") then
+            table.insert(torch_links, "c10_hip")
+        end
+        target:add("links", table.unpack(torch_links), { public = true })
+        -- Hard-pin runtime dependency entries to avoid the linker dropping HIP torch libs.
+        target:add("ldflags", "-L" .. torch_libdir, {force = true})
+        if has_torch_lib("torch_hip") then
+            target:add("ldflags", "-l:libtorch_hip.so", {force = true})
+        end
+        if has_torch_lib("c10_hip") then
+            target:add("ldflags", "-l:libc10_hip.so", {force = true})
+        end
+        if has_torch_lib("torch_cuda") then
+            target:add("ldflags", "-l:libtorch_cuda.so", {force = true})
+        end
+        if has_torch_lib("c10_cuda") then
+            target:add("ldflags", "-l:libc10_cuda.so", {force = true})
+        end
+        target:add("ldflags", "-Wl,--as-needed", {force = true})
+        print("Torch libraries: " .. table.concat(torch_links, ", "))
end
end)
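ROCm-derived torch builds (which DTK's torch appears to follow) ship `libtorch_hip.so`/`libc10_hip.so` in place of the CUDA-named libraries, so both spellings are probed above and whichever variant exists gets linked.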
......@@ -515,6 +634,40 @@ target("infinicore_cpp_api")
add_installfiles("include/infinicore/(**/*.hpp)",{prefixdir = "include/infinicore"})
add_installfiles("include/infinicore.h", {prefixdir = "include"})
add_installfiles("include/infinicore.hpp", {prefixdir = "include"})
+after_install(function (target)
+    if not has_config("hygon-dcu") then return end
+    local fa_prebuilt = get_config("flash-attn-prebuilt")
+    if not fa_prebuilt or fa_prebuilt == "" then
+        fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT")
+    end
+    local flash_so_path = nil
+    if fa_prebuilt and fa_prebuilt ~= "" then
+        if os.isfile(fa_prebuilt) then
+            flash_so_path = fa_prebuilt
+        else
+            local files = os.files(path.join(fa_prebuilt, "flash_attn_2_cuda*.so"))
+            if #files > 0 then flash_so_path = files[1] end
+        end
+    else
+        local ok, so_path = pcall(function()
+            return os.iorunv("python", {"-c", "import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)"}):trim()
+        end)
+        if ok and so_path and so_path ~= "" and os.isfile(so_path) then
+            flash_so_path = so_path
+        end
+    end
+    if flash_so_path then
+        local installdir = target:installdir()
+        local libdir = path.join(installdir, "lib")
+        os.mkdir(libdir)
+        os.cp(flash_so_path, libdir)
+        print("Copied prebuilt flash_attn library to " .. libdir)
+    end
+end)
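The copy step presumably makes an installed tree self-contained: the `flash_attn_2_cuda*.so` that was linked against travels with the infinicore libraries rather than being resolved from the build machine's python environment at runtime.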
after_build(function (target) print(YELLOW .. "[Congratulations!] Now you can install the libraries with \"xmake install\"" .. NC) end)
target_end()
......
......@@ -64,14 +64,14 @@ target("infiniop-hygon")
    -- Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; fall back to a default if HYGON_ARCH is unset
-    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
+    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936"
    add_cuflags("-arch=" .. hygon_arch)
-    print("编译海光DCU架构: " .. hygon_arch)
+    print("Compiling Hygon DCU architecture: " .. hygon_arch)
    -- Reuse the NVIDIA CUDA implementation through the HIP compatibility layer
    add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")
-    -- temporarily disable paged ops for hygon
-    remove_files("../src/infiniop/ops/paged*/nvidia/*.cu")
+    -- paged ops re-enabled for hygon (previously segfaulted on gfx936 before HIP adaptation)
+    -- remove_files("../src/infiniop/ops/paged*/nvidia/*.cu")
if has_config("ninetoothed") then
add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {cxxflags = {"-Wno-return-type"}})
......@@ -107,9 +107,9 @@ target("infinirt-hygon")
    -- Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; fall back to a default if HYGON_ARCH is unset
-    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
+    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936"
    add_cuflags("-arch=" .. hygon_arch)
    add_files("../src/infinirt/cuda/*.cu")
target_end()
......@@ -143,7 +143,7 @@ target("infiniccl-hygon")
    -- Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; fall back to a default if HYGON_ARCH is unset
-    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906"
+    local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936"
    add_cuflags("-arch=" .. hygon_arch)
    -- Use NCCL (NVIDIA Collective Communications Library)
......