jerrrrry / infinicore · Commit 06362c94

issue/1033 add flash-attn compile target

Authored Mar 05, 2026 by PanZezhong; committed Mar 05, 2026 by wooway777.
Parent: 515e1eca

Showing 7 changed files, with 135 additions and 29 deletions:
  include/infinicore/adaptor/aten_adaptor.hpp                              +8   -1
  include/infinicore/adaptor/flash_attention_adaptor.hpp                   +3   -1
  src/infinicore/adaptor/aten_adaptor.cc                                   +7   -1
  src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc   +6   -0
  src/infinicore/pybind11/ops.hpp                                          +1   -1
  xmake.lua                                                                +56  -25
  xmake/nvidia.lua                                                         +54   -0
File: include/infinicore/adaptor/aten_adaptor.hpp
```cpp
#ifdef ENABLE_ATEN
#pragma once

#include "../context/context.hpp"
#include "../tensor.hpp"

#include <ATen/ATen.h>
#ifdef ENABLE_NVIDIA_API
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#endif

namespace infinicore::adaptor {

inline at::ScalarType to_at_dtype(DataType dtype) {
    ...

@@ -37,5 +40,9 @@ inline at::Device to_at_device(const Device &device) {

at::Tensor to_aten_tensor(const infinicore::Tensor &t);

#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream();
#endif

} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
```
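The double gate means `get_cuda_stream()` exists only in builds where both ATen and the NVIDIA backend are enabled. A minimal usage sketch (not part of the commit; `scale_inplace` is a hypothetical helper, and it assumes `to_aten_tensor` wraps the tensor's existing device buffer rather than copying it):

```cpp
// Hypothetical sketch, assuming a build with ENABLE_ATEN and ENABLE_NVIDIA_API.
#include "infinicore/adaptor/aten_adaptor.hpp"

void scale_inplace(const infinicore::Tensor &t) {
    // Route subsequent ATen CUDA kernels onto infinicore's current stream
    // instead of PyTorch's default stream.
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());

    // View the infinicore tensor as an at::Tensor and run an ATen op on it;
    // assumes the conversion is a non-copying view of the same device memory.
    at::Tensor a = infinicore::adaptor::to_aten_tensor(t);
    a.mul_(2.0);
}
```

The guard is what keeps the ATen launch ordered with infinicore's own kernels on the same stream.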
File: include/infinicore/adaptor/flash_attention_adaptor.hpp
```cpp
#ifdef ENABLE_FLASH_ATTN
#pragma once

#include "aten_adaptor.hpp"
    ...

@@ -109,4 +110,5 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits);

} // namespace flash

#endif // ENABLE_FLASH_ATTN
```
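This header appears to mirror the declaration of `flash::mha_fwd_kvcache` from flash-attention's `flash_api.cpp`, so InfiniCore can call into the separately compiled `flash-attn-nvidia` library (built below in `xmake/nvidia.lua`); that reading is inferred from this diff rather than stated in the commit. The whole header is now additionally fenced behind `ENABLE_FLASH_ATTN`, matching the new build option.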
File: src/infinicore/adaptor/aten_adaptor.cc
```cpp
#ifdef ENABLE_ATEN
#include "infinicore/adaptor/aten_adaptor.hpp"

namespace infinicore::adaptor {
    ...

@@ -31,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
                          options);
}

#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream() {
    return c10::cuda::getStreamFromExternal(
        cudaStream_t(infinicore::context::getStream()),
        infinicore::context::getDevice().getIndex());
}
#endif

} // namespace infinicore::adaptor

#endif // ENABLE_ATEN
```
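One detail worth knowing about `c10::cuda::getStreamFromExternal`: the returned `CUDAStream` is a non-owning wrapper, so PyTorch will submit work to infinicore's stream but never destroy it; stream lifetime stays with the infinicore runtime. A standalone illustration of that contract (hypothetical code, independent of this commit):

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

void adopt_external_stream_example() {
    cudaStream_t raw = nullptr;
    cudaStreamCreate(&raw); // created, and owned, outside PyTorch
    {
        // Wrap for c10: ATen will launch onto `raw` but never destroy it.
        c10::cuda::CUDAStream s =
            c10::cuda::getStreamFromExternal(raw, /*device_index=*/0);
        c10::cuda::CUDAStreamGuard guard(s);
        // ... enqueue ATen ops here; they execute on `raw` ...
    }
    cudaStreamSynchronize(raw); // drain before tearing down
    cudaStreamDestroy(raw);     // our responsibility, not PyTorch's
}
```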
File: src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc
```cpp
@@ -2,6 +2,8 @@
#include "infinicore/adaptor/flash_attention_adaptor.hpp"

#include <stdexcept>

namespace infinicore::op::mha_varlen_impl::flashattn {

struct PlannedMeta {
    ...

@@ -38,6 +40,7 @@ void *plan(Tensor out,
}

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
    ...

@@ -77,6 +80,9 @@ void run(void *planned_meta) {
        0.0,
        false,
        std::nullopt);
#else
    throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
}

void cleanup(void **planned_meta_ptr) {
    ...
```
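For orientation: this backend exposes the op through a plan/run/cleanup lifecycle. `plan` packs the tensors and launch parameters into a heap-allocated `PlannedMeta` and returns it as an opaque `void *`; `run` casts it back, installs a `CUDAStreamGuard` so the flash-attention kernels launch on infinicore's stream, and calls into the `flash` API; `cleanup` takes a `void **`, which suggests it frees the `PlannedMeta` and nulls the caller's handle. In builds without `ENABLE_FLASH_ATTN`, `run` now fails loudly with `std::runtime_error` rather than silently doing nothing.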
File: src/infinicore/pybind11/ops.hpp
```diff
@@ -12,8 +12,8 @@
 #include "ops/linear.hpp"
 #include "ops/linear_w8a8i8.hpp"
 #include "ops/matmul.hpp"
-#include "ops/mul.hpp"
+#include "ops/mha_varlen.hpp"
 #include "ops/mul.hpp"
 #include "ops/paged_attention.hpp"
 #include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
```
File: xmake.lua
```lua
@@ -226,6 +226,28 @@ if has_config("ninetoothed") then
    add_defines("ENABLE_NINETOOTHED")
end

-- ATen
option("aten")
    set_default(false)
    set_showmenu(true)
    set_description("Whether to link aten and torch libraries")
option_end()

-- Flash-Attn
option("flash-attn")
    set_default(nil)
    set_showmenu(true)
    set_description("Path to flash-attention repo. If not set, flash-attention will not be used.")
option_end()

if has_config("aten") then
    add_defines("ENABLE_ATEN")
    if get_config("flash-attn") ~= nil then
        add_defines("ENABLE_FLASH_ATTN")
    end
end

-- cuda graph
option("graph")
    set_default(false)
    ...
```
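With these options in place, a flash-attention build is presumably configured along the lines of `xmake f --aten=y --flash-attn=/path/to/flash-attention` (standard xmake option syntax; the flag names follow from the `option()` declarations above), after which `ENABLE_ATEN` and `ENABLE_FLASH_ATTN` are defined for the whole build. Note that `flash-attn` has no effect unless `aten` is also enabled, since the `ENABLE_FLASH_ATTN` define is nested inside the `has_config("aten")` check.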
```diff
@@ -439,31 +461,40 @@ target("infinicore_cpp_api")
     add_linkdirs(INFINI_ROOT .. "/lib")
     add_links("infiniop", "infinirt", "infiniccl")
 
-    -- ==============================
-    -- LibTorch integration
-    -- ==============================
-    local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch")
-    -- headers
-    add_includedirs(path.join(LIBTORCH_ROOT, "include"),
-                    path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"),
-                    {public = true})
-    -- libraries
-    add_linkdirs(path.join(LIBTORCH_ROOT, "lib"))
-    -- core ATen / Torch libs
-    add_links("torch", "c10", "torch_cuda", "c10_cuda")
-    -- Flash attention lib
-    add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build")
-    add_links("flash_attn")
+    if get_config("flash-attn") ~= nil then
+        add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
+        if has_config("nv-gpu") then
+            add_deps("flash-attn-nvidia")
+        end
+    end
+
+    before_build(function (target)
+        if has_config("aten") then
+            local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
+            local TORCH_DIR = outdata
+            target:add("includedirs",
+                       path.join(TORCH_DIR, "include"),
+                       path.join(TORCH_DIR, "include/torch/csrc/api/include"),
+                       {public = true})
+            target:add("linkdirs", path.join(TORCH_DIR, "lib"), {public = true})
+            target:add("links", "torch", "c10", "torch_cuda", "c10_cuda", {public = true})
+        end
+    end)
 
     -- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
     add_files("src/infinicore/*.cc")
     ...
```
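The net effect of this hunk is to replace hard-coded, machine-specific LibTorch and flash-attention paths with configuration resolved at build time: `before_build` shells out to the active `python` to locate the installed torch package, so the build tracks whatever environment the developer has activated. The flip side is that the library is compiled against that environment's torch, so it must also run against a compatible one.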
File: xmake/nvidia.lua
```lua
@@ -9,6 +9,10 @@ if CUTLASS_ROOT ~= nil then
    add_includedirs(CUTLASS_ROOT)
end

local FLASH_ATTN_ROOT = get_config("flash-attn")
local INFINI_ROOT = os.getenv("INFINI_ROOT")
    or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")

target("infiniop-nvidia")
    set_kind("static")
    add_deps("infini-utils")
    ...
```
```lua
@@ -132,3 +136,53 @@ target("infiniccl-nvidia")
    set_languages("cxx17")
target_end()

target("flash-attn-nvidia")
    set_kind("shared")
    set_default(false)
    set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
    add_links("cudart")
    add_cugencodes("native")

    before_build(function (target)
        if FLASH_ATTN_ROOT ~= nil then
            local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
            local TORCH_DIR = outdata
            local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
            local PYTHON_INCLUDE = outdata
            local outdata = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
            local PYTHON_LIB_DIR = outdata

            -- Include dirs
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false})
            target:add("includedirs", TORCH_DIR .. "/include", {public = false})
            target:add("includedirs", PYTHON_INCLUDE, {public = false})
            target:add("includedirs", CUTLASS_ROOT .. "/include", {public = false})
            target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn", {public = false})

            -- Link libraries
            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR)
            target:add("links", "torch", "torch_cuda", "torch_cpu", "c10", "c10_cuda", "torch_python", "python3")
        end
    end)

    if FLASH_ATTN_ROOT ~= nil then
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp")
        add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu")

        -- Link options
        add_ldflags("-Wl,--no-undefined", {force = true})
        -- Compile options
        add_cxflags("-fPIC", {force = true})
        add_cuflags("-Xcompiler=-fPIC")
        add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true})
        set_values("cuda.rdc", false)
    end

    on_install(function (target)
    end)
target_end()
```
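Three details of this new target: it compiles flash-attention's `flash_api.cpp` and the `src/*.cu` kernels straight out of the checkout passed via `--flash-attn`, so that path must point at a flash-attention repo containing `csrc/flash_attn`; `-Wl,--no-undefined` turns any unresolved symbol into a link error, which is presumably why the full torch, `torch_python`, and `python3` set is linked explicitly rather than leaving Python symbols to resolve at load time; and the empty `on_install` handler appears to suppress the default install step, leaving installation of the produced `flash-attn*.so` to the `add_installfiles` rule in `xmake.lua`.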