"tests/git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "9759527111a8103bbbd8fc863e50d728c5101c8e"
Commit 06362c94 authored by PanZezhong, committed by wooway777

issue/1033 add flash-attn compile target

parent 515e1eca
+#ifdef ENABLE_ATEN
 #pragma once
 #include "../context/context.hpp"
 #include "../tensor.hpp"
 #include <ATen/ATen.h>
+#ifdef ENABLE_NVIDIA_API
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#endif
 namespace infinicore::adaptor {
 inline at::ScalarType to_at_dtype(DataType dtype) {
@@ -37,5 +40,9 @@ inline at::Device to_at_device(const Device &device) {
 at::Tensor to_aten_tensor(const infinicore::Tensor &t);
+#ifdef ENABLE_NVIDIA_API
 c10::cuda::CUDAStream get_cuda_stream();
+#endif
-} // namespace infinicore::adaptor
\ No newline at end of file
+} // namespace infinicore::adaptor
+#endif // ENABLE_ATEN
+#ifdef ENABLE_FLASH_ATTN
 #pragma once
 #include "aten_adaptor.hpp"
@@ -109,4 +110,5 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits);
-} // namespace flash
\ No newline at end of file
+} // namespace flash
+#endif // ENABLE_FLASH_ATTN
+#ifdef ENABLE_ATEN
 #include "infinicore/adaptor/aten_adaptor.hpp"
 namespace infinicore::adaptor {
@@ -31,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
         options);
 }
+#ifdef ENABLE_NVIDIA_API
 c10::cuda::CUDAStream get_cuda_stream() {
     return c10::cuda::getStreamFromExternal(
         cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
 }
+#endif
-} // namespace infinicore::adaptor
\ No newline at end of file
+
+} // namespace infinicore::adaptor
+#endif // ENABLE_ATEN
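`get_cuda_stream` wraps InfiniCore's raw `cudaStream_t` in a `c10::cuda::CUDAStream` via `getStreamFromExternal`, without transferring ownership. A self-contained sketch of the same pattern, with a locally created stream standing in for `infinicore::context::getStream()` and device index 0 assumed:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

void run_on_external_stream() {
    cudaStream_t raw = nullptr;
    cudaStreamCreate(&raw); // stands in for infinicore::context::getStream()
    {
        // Wrap the raw handle; torch does not take ownership of it.
        auto stream = c10::cuda::getStreamFromExternal(raw, /*device_index=*/0);
        c10::cuda::CUDAStreamGuard guard(stream);
        // ATen kernels launched here are enqueued on `raw`.
    }
    cudaStreamSynchronize(raw);
    cudaStreamDestroy(raw); // still our job, since torch never owned it
}
```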
@@ -2,6 +2,8 @@
 #include "infinicore/adaptor/flash_attention_adaptor.hpp"
+#include <stdexcept>
 namespace infinicore::op::mha_varlen_impl::flashattn {
 struct PlannedMeta {
@@ -38,6 +40,7 @@ void *plan(Tensor out,
 }
 void run(void *planned_meta) {
+#ifdef ENABLE_FLASH_ATTN
     c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
@@ -77,6 +80,9 @@ void run(void *planned_meta) {
         0.0,
         false,
         std::nullopt);
+#else
+    throw std::runtime_error("FlashAttention is not enabled in this build");
+#endif
 }
 void cleanup(void **planned_meta_ptr) {
......
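Because `run` now throws on builds configured without flash-attn, a caller that must degrade gracefully can catch the error; a hedged sketch in which `try_run` is purely illustrative:

```cpp
#include <iostream>
#include <stdexcept>

namespace infinicore::op::mha_varlen_impl::flashattn {
void run(void *planned_meta); // declared in the op's headers (shown above)
}

void try_run(void *planned_meta) {
    try {
        infinicore::op::mha_varlen_impl::flashattn::run(planned_meta);
    } catch (const std::runtime_error &e) {
        // Reached when the build was configured without --flash-attn.
        std::cerr << "mha_varlen unavailable: " << e.what() << '\n';
    }
}
```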
@@ -12,8 +12,8 @@
 #include "ops/linear.hpp"
 #include "ops/linear_w8a8i8.hpp"
 #include "ops/matmul.hpp"
-#include "ops/mul.hpp"
 #include "ops/mha_varlen.hpp"
+#include "ops/mul.hpp"
 #include "ops/paged_attention.hpp"
 #include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
......
@@ -226,6 +226,28 @@ if has_config("ninetoothed") then
     add_defines("ENABLE_NINETOOTHED")
 end
+-- ATen
+option("aten")
+    set_default(false)
+    set_showmenu(true)
+    set_description("Whether to link the ATen and torch libraries")
+option_end()
+
+-- Flash-Attn
+option("flash-attn")
+    set_default(nil)
+    set_showmenu(true)
+    set_description("Path to the flash-attention repo. If not set, flash-attention will not be used.")
+option_end()
+
+if has_config("aten") then
+    add_defines("ENABLE_ATEN")
+    if get_config("flash-attn") ~= nil then
+        add_defines("ENABLE_FLASH_ATTN")
+    end
+end
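These two options drive the `ENABLE_ATEN` / `ENABLE_FLASH_ATTN` gates used in the C++ changes above, and `ENABLE_FLASH_ATTN` is only defined when `aten` is also on, so it implies `ENABLE_ATEN`. A typical configure invocation would be something like `xmake f --aten=y --flash-attn=/path/to/flash-attention` (path illustrative). A minimal sketch of the resulting compile-time layering:

```cpp
// Layering implied by the options above (sketch, not from this commit):
//   --aten=y                      -> ENABLE_ATEN
//   --aten=y --flash-attn=<path>  -> ENABLE_ATEN and ENABLE_FLASH_ATTN
#if defined(ENABLE_FLASH_ATTN) && !defined(ENABLE_ATEN)
#error "ENABLE_FLASH_ATTN is only ever defined together with ENABLE_ATEN"
#endif
```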
 -- cuda graph
 option("graph")
     set_default(false)
@@ -439,31 +461,40 @@ target("infinicore_cpp_api")
     add_linkdirs(INFINI_ROOT.."/lib")
     add_links("infiniop", "infinirt", "infiniccl")
-    -- ==============================
-    -- LibTorch integration
-    -- ==============================
-    local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch")
-
-    -- headers
-    add_includedirs(
-        path.join(LIBTORCH_ROOT, "include"),
-        path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"),
-        { public = true }
-    )
-    -- libraries
-    add_linkdirs(path.join(LIBTORCH_ROOT, "lib"))
-
-    -- core ATen / Torch libs
-    add_links(
-        "torch",
-        "c10",
-        "torch_cuda",
-        "c10_cuda"
-    )
-    -- Flash attention lib
-    add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build")
-    add_links("flash_attn")
+    if get_config("flash-attn") ~= nil then
+        add_installfiles("$(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
+        if has_config("nv-gpu") then
+            add_deps("flash-attn-nvidia")
+        end
+    end
+
+    before_build(function (target)
+        if has_config("aten") then
+            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
+            target:add(
+                "includedirs",
+                path.join(TORCH_DIR, "include"),
+                path.join(TORCH_DIR, "include/torch/csrc/api/include"),
+                { public = true })
+            target:add(
+                "linkdirs",
+                path.join(TORCH_DIR, "lib"),
+                { public = true }
+            )
+            target:add(
+                "links",
+                "torch",
+                "c10",
+                "torch_cuda",
+                "c10_cuda",
+                { public = true }
+            )
+        end
+    end)
     -- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
     add_files("src/infinicore/*.cc")
......
@@ -9,6 +9,10 @@ if CUTLASS_ROOT ~= nil then
     add_includedirs(CUTLASS_ROOT)
 end
+local FLASH_ATTN_ROOT = get_config("flash-attn")
+local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
+
 target("infiniop-nvidia")
     set_kind("static")
     add_deps("infini-utils")
@@ -132,3 +136,53 @@ target("infiniccl-nvidia")
     set_languages("cxx17")
 target_end()
+
+target("flash-attn-nvidia")
+    set_kind("shared")
+    set_default(false)
+    set_policy("build.cuda.devlink", true)
+    set_toolchains("cuda")
+    add_links("cudart")
+    add_cugencodes("native")
+
+    before_build(function (target)
+        if FLASH_ATTN_ROOT ~= nil then
+            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
+            local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
+            local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
+
+            -- Include dirs
+            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
+            target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
+            target:add("includedirs", TORCH_DIR .. "/include", {public = false})
+            target:add("includedirs", PYTHON_INCLUDE, {public = false})
+            target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
+            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})
+
+            -- Link libraries
+            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
+            target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3")
+        end
+    end)
+
+    if FLASH_ATTN_ROOT ~= nil then
+        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
+        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")
+
+        -- Link options
+        add_ldflags("-Wl,--no-undefined", {force = true})
+
+        -- Compile options
+        add_cxflags("-fPIC", {force = true})
+        add_cuflags("-Xcompiler=-fPIC")
+        add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
+        set_values("cuda.rdc", false)
+    end
+
+    on_install(function (target) end)
+target_end()