Commit 06362c94 authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/1033 add flash-attn compile target

parent 515e1eca
#ifdef ENABLE_ATEN
#pragma once
#include "../context/context.hpp"
#include "../tensor.hpp"
#include <ATen/ATen.h>
#ifdef ENABLE_NVIDIA_API
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#endif
namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) {
......@@ -37,5 +40,9 @@ inline at::Device to_at_device(const Device &device) {
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream();
} // namespace infinicore::adaptor
\ No newline at end of file
#endif
} // namespace infinicore::adaptor
#endif // ENABLE_ATEN
#ifdef ENABLE_FLASH_ATTN
#pragma once
#include "aten_adaptor.hpp"
......@@ -109,4 +110,5 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);
} // namespace flash
\ No newline at end of file
} // namespace flash
#endif // ENABLE_FLASH_ATTN
#ifdef ENABLE_ATEN
#include "infinicore/adaptor/aten_adaptor.hpp"
namespace infinicore::adaptor {
......@@ -31,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options);
}
#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream() {
    // Wrap infinicore's native stream handle in a c10 stream object so that
    // ATen / flash-attention kernels launch on the same CUDA stream.
    auto device_index = infinicore::context::getDevice().getIndex();
    auto raw_stream = cudaStream_t(infinicore::context::getStream());
    return c10::cuda::getStreamFromExternal(raw_stream, device_index);
}
} // namespace infinicore::adaptor
\ No newline at end of file
#endif
} // namespace infinicore::adaptor
#endif // ENABLE_ATEN
......@@ -2,6 +2,8 @@
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace infinicore::op::mha_varlen_impl::flashattn {
struct PlannedMeta {
......@@ -38,6 +40,7 @@ void *plan(Tensor out,
}
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
......@@ -77,6 +80,9 @@ void run(void *planned_meta) {
0.0,
false,
std::nullopt);
#else
throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
}
void cleanup(void **planned_meta_ptr) {
......
......@@ -12,8 +12,8 @@
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
......
......@@ -226,6 +226,28 @@ if has_config("ninetoothed") then
add_defines("ENABLE_NINETOOTHED")
end
-- ATen
-- Build option: link against ATen / LibTorch (off by default).
option("aten")
    set_default(false)
    set_showmenu(true)
    -- Fix typo in user-facing description ("Wether" -> "Whether").
    set_description("Whether to link aten and torch libraries")
option_end()
-- Flash-Attn
-- Build option: path to a flash-attention checkout. Leaving it unset
-- disables the flash-attention target and the ENABLE_FLASH_ATTN define.
option("flash-attn")
    set_default(nil)
    set_showmenu(true)
    -- Fix grammar in user-facing description ("will not used" -> "will not be used").
    set_description("Path to flash-attention repo. If not set, flash-attention will not be used.")
option_end()
-- Enable the ATen adaptor when requested; flash-attention additionally
-- requires a repo path to have been supplied via the flash-attn option.
if has_config("aten") then
    add_defines("ENABLE_ATEN")
    local flash_root = get_config("flash-attn")
    if flash_root ~= nil then
        add_defines("ENABLE_FLASH_ATTN")
    end
end
-- cuda graph
option("graph")
set_default(false)
......@@ -439,31 +461,40 @@ target("infinicore_cpp_api")
add_linkdirs(INFINI_ROOT.."/lib")
add_links("infiniop", "infinirt", "infiniccl")
-- ==============================
-- LibTorch integration
-- ==============================
local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch")
-- headers
add_includedirs(
path.join(LIBTORCH_ROOT, "include"),
path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"),
{ public = true }
)
-- libraries
add_linkdirs(path.join(LIBTORCH_ROOT, "lib"))
-- core ATen / Torch libs
add_links(
"torch",
"c10",
"torch_cuda",
"c10_cuda"
)
-- Flash attention lib
add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build")
add_links("flash_attn")
-- When a flash-attention repo path was configured, ship the built shared
-- library alongside this target and make sure it is built first.
if get_config("flash-attn") ~= nil then
-- Install the compiled flash-attn shared object into the lib prefix.
add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
-- The flash-attn kernels are only built for the NVIDIA backend.
if has_config("nv-gpu") then
add_deps("flash-attn-nvidia")
end
end
-- Resolve LibTorch include/link paths at build time from the active Python
-- environment, so no torch installation path has to be hard-coded.
before_build(function (target)
    if has_config("aten") then
        -- Ask Python where torch is installed; trim the trailing newline.
        -- (Replaces the redundant `outdata` intermediate local.)
        local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
        -- LibTorch headers (core ATen plus the C++ frontend API).
        target:add(
            "includedirs",
            path.join(TORCH_DIR, "include"),
            path.join(TORCH_DIR, "include/torch/csrc/api/include"),
            { public = true })
        -- LibTorch shared libraries.
        target:add(
            "linkdirs",
            path.join(TORCH_DIR, "lib"),
            { public = true }
        )
        -- Core ATen / Torch libs (CPU + CUDA).
        target:add(
            "links",
            "torch",
            "c10",
            "torch_cuda",
            "c10_cuda",
            { public = true }
        )
    end
end)
-- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
add_files("src/infinicore/*.cc")
......
......@@ -9,6 +9,10 @@ if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
end
-- Path to the flash-attention checkout; nil when the option was not set.
local FLASH_ATTN_ROOT = get_config("flash-attn")
-- InfiniCore install prefix: $INFINI_ROOT, defaulting to ~/.infini
-- (HOMEPATH on Windows, HOME elsewhere).
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
target("infiniop-nvidia")
set_kind("static")
add_deps("infini-utils")
......@@ -132,3 +136,53 @@ target("infiniccl-nvidia")
set_languages("cxx17")
target_end()
-- Shared library wrapping the upstream flash-attention CUDA kernels.
-- Only built when FLASH_ATTN_ROOT is configured (set_default(false)).
target("flash-attn-nvidia")
    set_kind("shared")
    set_default(false)
    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")
    add_cugencodes("native")

    before_build(function (target)
        if FLASH_ATTN_ROOT ~= nil then
            -- Query the active Python environment for a path; trims the
            -- trailing newline. Replaces three shadowing `local outdata`
            -- declarations with one helper.
            local function py_query(code)
                return os.iorunv("python", {"-c", code}):trim()
            end
            local TORCH_DIR = py_query("import torch, os; print(os.path.dirname(torch.__file__))")
            local PYTHON_INCLUDE = py_query("import sysconfig; print(sysconfig.get_paths()['include'])")
            local PYTHON_LIB_DIR = py_query("import sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
            -- Include dirs: flash-attn sources, LibTorch, CPython, CUTLASS.
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include", {public = false})
            target:add("includedirs", PYTHON_INCLUDE, {public = false})
            target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})
            -- Link libraries: LibTorch (CPU + CUDA), the Python bindings and CPython itself.
            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
            target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3")
        end
    end)

    if FLASH_ATTN_ROOT ~= nil then
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
        -- Link options: fail the link on any unresolved symbol.
        add_ldflags("-Wl,--no-undefined", {force = true})
        -- Compile options: position-independent code for the shared library.
        add_cxflags("-fPIC", {force = true})
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
        set_values("cuda.rdc", false)
    end
    -- No-op install hook: installation is handled by the consuming target.
    on_install(function (target) end)
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment