Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
6ac8f906
Commit
6ac8f906
authored
Jan 26, 2026
by
wooway777
Browse files
issue/919 - ninetoothed flash attention
parent
47843aa6
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
911 additions
and
5 deletions
+911
-5
include/infinicore/ops.hpp
include/infinicore/ops.hpp
+1
-0
include/infinicore/ops/flash_attention.hpp
include/infinicore/ops/flash_attention.hpp
+12
-0
include/infiniop.h
include/infiniop.h
+1
-0
include/infiniop/ops/flash_attention.h
include/infiniop/ops/flash_attention.h
+36
-0
python/infinicore/nn/functional/__init__.py
python/infinicore/nn/functional/__init__.py
+6
-4
python/infinicore/nn/functional/flash_attention.py
python/infinicore/nn/functional/flash_attention.py
+34
-0
src/infinicore/ops/flash_attention/flash_attention.cc
src/infinicore/ops/flash_attention/flash_attention.cc
+31
-0
src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
...nfinicore/ops/flash_attention/flash_attention_infiniop.cc
+55
-0
src/infinicore/pybind11/ops.hpp
src/infinicore/pybind11/ops.hpp
+3
-1
src/infinicore/pybind11/ops/flash_attention.hpp
src/infinicore/pybind11/ops/flash_attention.hpp
+22
-0
src/infiniop/ops/flash_attention/ninetoothed/build.py
src/infiniop/ops/flash_attention/ninetoothed/build.py
+46
-0
src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
+147
-0
src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py
...finiop/ops/flash_attention/ninetoothed/flash_attention.py
+281
-0
src/infiniop/ops/flash_attention/operator.cc
src/infiniop/ops/flash_attention/operator.cc
+121
-0
test/infinicore/ops/flash_attention.py
test/infinicore/ops/flash_attention.py
+115
-0
No files found.
include/infinicore/ops.hpp
View file @
6ac8f906
...
...
@@ -5,6 +5,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
...
...
include/infinicore/ops/flash_attention.hpp
0 → 100644
View file @
6ac8f906
#pragma once

#include "../device.hpp"

#include "common/op.hpp"

namespace infinicore::op {

// Declares the FlashAttention graph-op class. The trailing types describe the
// execute() signature: (Tensor out, const Tensor &q, const Tensor &k,
// const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal).
INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);

// Allocating variant: computes flash attention over q/k/v and returns a newly
// allocated output tensor. `total_kv_len` holds the effective KV length(s);
// `scale` is the softmax scaling factor; `is_causal` enables causal masking.
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);

// In-place variant: writes the attention result into the caller-provided `out`.
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);

} // namespace infinicore::op
include/infiniop.h
View file @
6ac8f906
...
...
@@ -10,6 +10,7 @@
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
...
...
include/infiniop/ops/flash_attention.h
0 → 100644
View file @
6ac8f906
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

// Opaque handle for a flash-attention operator descriptor.
typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

// Creates a descriptor for the given tensor layouts.
// `is_causal` is a boolean flag passed as `char` (0 = disabled, non-zero = causal).
// NOTE(review): `scale` is `float` here while the ninetoothed backend stores it
// as `double` — implicit widening at the call site; confirm intentional.
__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len,
    float scale,
    char is_causal);

// Queries the scratch-workspace size (in bytes) required by `infiniopFlashAttention`.
__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size);

// Executes flash attention. `workspace`/`workspace_size` must satisfy the size
// reported above; `stream` is the backend execution stream (may be NULL for default).
__C __export infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    const void *total_kv_len,
    void *stream);

// Releases a descriptor previously created by the Create call.
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc);

#endif
python/infinicore/nn/functional/__init__.py
View file @
6ac8f906
from
.causal_softmax
import
causal_softmax
from
.embedding
import
embedding
from
.flash_attention
import
flash_attention
from
.linear
import
linear
from
.random_sample
import
random_sample
from
.rms_norm
import
rms_norm
...
...
@@ -9,12 +10,13 @@ from .swiglu import swiglu
# Public API of infinicore.nn.functional.
# The diff-rendered list contained duplicated entries ("linear", "embedding",
# "rope", "RopeAlgo" appeared twice); each name is listed exactly once here.
__all__ = [
    "causal_softmax",
    "embedding",
    "flash_attention",
    "linear",
    "random_sample",
    "rms_norm",
    "RopeAlgo",
    "rope",
    "silu",
    "swiglu",
]
python/infinicore/nn/functional/flash_attention.py
0 → 100644
View file @
6ac8f906
import
math
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def flash_attention(
    query,
    key,
    value,
    total_kv_len,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Scaled dot-product attention via the flash-attention kernel.

    Signature mirrors ``torch.nn.functional.scaled_dot_product_attention``
    with an extra ``total_kv_len`` tensor giving the effective KV length.
    Only the default values of ``attn_mask``, ``dropout_p`` and
    ``enable_gqa`` are supported by the underlying kernel.

    Args:
        query, key, value: infinicore tensors with the head dimension last.
        total_kv_len: tensor holding the effective key/value length(s).
        scale: softmax scaling factor; defaults to ``1/sqrt(head_dim)``.
        is_causal: apply a causal mask when True.

    Returns:
        A new infinicore ``Tensor`` with the attention output.

    Raises:
        NotImplementedError: if an unsupported option is requested.
    """
    # Explicit validation instead of `assert`: asserts are stripped under -O.
    if attn_mask is not None:
        raise NotImplementedError("flash_attention: attn_mask is not supported")
    if dropout_p != 0:
        raise NotImplementedError("flash_attention: dropout_p must be 0")
    if enable_gqa:
        raise NotImplementedError("flash_attention: enable_gqa is not supported")

    if scale is None:
        # Default to 1/sqrt(head_dim), matching PyTorch's convention.
        scale = 1 / math.sqrt(query.shape[-1])

    return Tensor(
        _infinicore.flash_attention(
            query._underlying,
            key._underlying,
            value._underlying,
            total_kv_len._underlying,
            scale,
            is_causal,
        )
    )
src/infinicore/ops/flash_attention/flash_attention.cc
0 → 100644
View file @
6ac8f906
#include "infinicore/ops/flash_attention.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Generates the per-device dispatcher table for FlashAttention.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);

// Validates devices and dispatches to the backend registered for `out`'s device.
// NOTE(review): total_kv_len is excluded from the same-device check — confirm
// this is intentional (it may be allowed to live on a different device).
FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, total_kv_len, scale, is_causal);
}

// Records the op into the current graph, or runs it eagerly when not recording.
void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
}

// Allocating entry point: the output shape is q's shape with the last
// dimension replaced by v's last dimension (the value head dimension).
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    Shape shape = q->shape();
    int idx = shape.size() - 1;
    shape[idx] = v->shape()[idx];
    auto out = Tensor::empty(shape, q->dtype(), q->device());
    flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
    return out;
}

// In-place entry point: thin wrapper over FlashAttention::execute.
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
}

} // namespace infinicore::op
src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
0 → 100644
View file @
6ac8f906
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/flash_attention.hpp"

#include <infiniop.h>

namespace infinicore::op::flash_attention_impl::infiniop {

// LRU-cached infiniop descriptor wrapper (capacity 100 entries).
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);

// Everything the recorded graph node needs at run time.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, out, q, k, v, total_kv_len;
    float scale;
    bool is_causal;
};

// Plan phase: resolve (or create) the cached descriptor for this tensor
// configuration, allocate the workspace tensor, and capture graph handles.
void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    // Cache key combines all tensor descriptors and scalar attributes.
    size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);
    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(Descriptor, descriptor, FlashAttention, seed,
                                               out->desc(), q->desc(), k->desc(), v->desc(),
                                               total_kv_len->desc(), scale, is_causal);
    INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);
    auto planned = new PlannedMeta{descriptor,
                                   graph::GraphTensor(workspace),
                                   graph::GraphTensor(out),
                                   graph::GraphTensor(q),
                                   graph::GraphTensor(k),
                                   graph::GraphTensor(v),
                                   graph::GraphTensor(total_kv_len),
                                   scale,
                                   is_causal};
    return planned;
}

// Run phase: invoke the infiniop kernel with the planned pointers on the
// current stream.
// NOTE(review): workspace->numel() is passed as the workspace size — this is
// correct only if the workspace tensor is a byte tensor; confirm against
// INFINIOP_WORKSPACE_TENSOR's element type.
void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
    INFINICORE_CHECK_ERROR(infiniopFlashAttention(
        planned->descriptor->desc,
        planned->workspace->data(),
        planned->workspace->numel(),
        planned->out->data(),
        planned->q->data(),
        planned->k->data(),
        planned->v->data(),
        planned->total_kv_len->data(),
        context::getStream()));
}

// Cleanup phase: free the planned metadata and null the caller's pointer.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Register this infiniop-backed implementation for all devices.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup);

} // namespace infinicore::op::flash_attention_impl::infiniop
src/infinicore/pybind11/ops.hpp
View file @
6ac8f906
...
...
@@ -7,6 +7,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
...
...
@@ -29,13 +30,14 @@ inline void bind(py::module &m) {
bind_add_rms_norm
(
m
);
bind_attention
(
m
);
bind_causal_softmax
(
m
);
bind_
random_sample
(
m
);
bind_
flash_attention
(
m
);
bind_linear
(
m
);
bind_matmul
(
m
);
bind_mul
(
m
);
bind_paged_attention
(
m
);
bind_paged_attention_prefill
(
m
);
bind_paged_caching
(
m
);
bind_random_sample
(
m
);
bind_rearrange
(
m
);
bind_rms_norm
(
m
);
bind_silu
(
m
);
...
...
src/infinicore/pybind11/ops/flash_attention.hpp
0 → 100644
View file @
6ac8f906
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/flash_attention.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes op::flash_attention (the allocating variant) to Python as
// `flash_attention(q, k, v, total_kv_len, scale, is_causal)`.
inline void bind_flash_attention(py::module &m) {
    m.def("flash_attention",
          &op::flash_attention,
          py::arg("q"),
          py::arg("k"),
          py::arg("v"),
          py::arg("total_kv_len"),
          py::arg("scale"),
          py::arg("is_causal"));
}

} // namespace infinicore::ops
src/infiniop/ops/flash_attention/ninetoothed/build.py
0 → 100644
View file @
6ac8f906
import
ninetoothed
from
.
import
flash_attention
from
.flash_attention
import
CausalVariant
import
infiniop.ninetoothed.build
import
torch
def build():
    """Ahead-of-time compile the ninetoothed flash-attention kernels.

    Enumerates the constexpr parameter grid and hands it to the shared
    infiniop ninetoothed build driver, which emits the CUDA launchers
    consumed by ninetoothed/descriptor.h.
    """
    # Skip the build entirely on MetaX GPUs.
    # NOTE(review): this returns if *any* visible device is a MetaX —
    # confirm that mixed-device machines should also skip the build.
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        for i in range(device_count):
            device_name = torch.cuda.get_device_name(i).lower()
            if "metax" in device_name:
                return

    # Constexpr axes: one kernel variant is compiled per combination.
    with_kv_cache_values = (0,)
    emb_dim_values = (16, 32, 64, 128, 256)
    is_causal_values = (0, 1)
    with_attn_mask_values = (0,)
    causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
    dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
    # Tile sizes; must match the constants hard-coded in descriptor.h.
    block_size_m_values = (256,)
    block_size_n_values = (64,)

    constexpr_param_grid = {
        "with_kv_cache": with_kv_cache_values,
        "emb_dim": emb_dim_values,
        "is_causal": is_causal_values,
        "with_attn_mask": with_attn_mask_values,
        "causal_variant": causal_variant_values,
        "dtype": dtype_values,
        "block_size_m": block_size_m_values,
        "block_size_n": block_size_n_values,
    }

    infiniop.ninetoothed.build.build(
        flash_attention.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="flash_attention",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
0 → 100644
View file @
6ac8f906
#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
#define __FLASH_ATTENTION_DESCRIPTOR_H__

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/flash_attention.h"
#include "../../../ninetoothed/utils.h"

namespace op::flash_attention::ninetoothed {

// Descriptor for the ninetoothed (Triton-generated) flash-attention backend.
// Captures shapes/strides at creation time; calculate() wraps raw device
// pointers into ninetoothed tensors and launches the AOT-compiled kernel.
class Descriptor final : public InfiniopDescriptor {
public:
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t out_desc,
               infiniopTensorDescriptor_t q_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t total_kv_len,
               double scale,
               char is_causal)
        : InfiniopDescriptor{handle->device, handle->device_id},
          _query_shape{q_desc->shape()},
          _query_strides{q_desc->strides()},
          _key_shape{k_desc->shape()},
          _key_strides{k_desc->strides()},
          _value_shape{v_desc->shape()},
          _value_strides{v_desc->strides()},
          _total_kv_shape{total_kv_len->shape()},
          _total_kv_strides{total_kv_len->strides()},
          _output_strides{out_desc->strides()},
          _dtype{q_desc->dtype()},
          _scale{scale},
          _is_causal{is_causal} {}

    ~Descriptor() = default;

    // The ninetoothed kernel needs no scratch memory.
    size_t get_workspace_size() const {
        return 0;
    }

    // Launches the kernel. `workspace`/`workspace_size` are unused (size 0).
    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *out,
        const void *q,
        const void *k,
        const void *v,
        const void *total_kv_len,
        void *stream) const {
        // Placeholder shape/strides for the (disabled) attention mask.
        // NOTE(review): left uninitialized — harmless only as long as
        // with_attn_mask_ stays 0 and the kernel never reads them; confirm.
        uint64_t empty_shape[4];
        int64_t empty_strides[4];

        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
        auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
        NineToothedTensor is_causal;
        // Scalar scale is passed by pointer; const_cast because the C launcher
        // takes a non-const data pointer.
        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
        // Output reuses the query shape (same seq length) with its own strides.
        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
        NineToothedTensor with_attn_mask;
        NineToothedTensor causal_variant;

        // Constexpr kernel-selection parameters; these must match a variant
        // compiled by ninetoothed/build.py.
        const auto with_kv_cache_{0};
        const auto emb_dim_{_query_shape[3]}; // assumes rank-4 (b, h, s, d) layout — TODO confirm
        const auto is_causal_{_is_causal};
        const auto with_attn_mask_{0};
        // NOTE(review): hard-coded to 2 (CausalVariant.LOWER_RIGHT) regardless
        // of caller input — confirm intentional.
        const auto causal_variant_{2};
        const auto dtype_{_dtype};
        // Tile sizes mirror block_size_m/n in build.py's parameter grid.
        constexpr auto block_size_m_{256};
        constexpr auto block_size_n_{64};

        // Non-zero launcher return means no matching compiled variant.
        if (launch_flash_attention(stream, query, key, value, total_kv_length,
                                   attn_mask, is_causal, scale, output,
                                   with_attn_mask, causal_variant,
                                   with_kv_cache_, emb_dim_, is_causal_,
                                   with_attn_mask_, causal_variant_, dtype_,
                                   block_size_m_, block_size_n_)) {
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }

        return INFINI_STATUS_SUCCESS;
    }

    // Factory used by the operator.cc dispatch layer.
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc,
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t total_kv_len,
        double scale,
        char is_causal) {
        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    // Captured tensor geometry (copied at descriptor creation).
    std::vector<Size> _query_shape;
    std::vector<Stride> _query_strides;
    std::vector<Size> _key_shape;
    std::vector<Stride> _key_strides;
    std::vector<Size> _value_shape;
    std::vector<Stride> _value_strides;
    std::vector<Size> _total_kv_shape;
    std::vector<Stride> _total_kv_strides;
    std::vector<Stride> _output_strides;
    infiniDtype_t _dtype;   // element type of q (assumed shared by k/v/out)
    double _scale;          // softmax scale, passed to the kernel by pointer
    char _is_causal;        // boolean flag (0 / non-zero)
};

} // namespace op::flash_attention::ninetoothed

#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py
0 → 100644
View file @
6ac8f906
import
enum
import
functools
import
ninetoothed
import
ninetoothed.language
as
ntl
from
ninetoothed
import
Tensor
# Symbolic (auto-tunable) tile sizes used when no explicit block size is
# supplied to `arrangement`.
BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()
class CausalVariant(enum.IntEnum):
    """Causal-mask alignment variants.

    Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_.
    """

    # Values match enum.auto() numbering (1, 2) so compiled-kernel selection
    # and the hard-coded `causal_variant == 2` check stay valid.
    UPPER_LEFT = 1
    LOWER_RIGHT = 2
def arrangement(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
    with_kv_cache,
    block_size_m=None,
    block_size_n=None,
):
    """Describe how each kernel argument is tiled across the launch grid.

    Returns the arranged tensors in the argument order expected by the
    matching application function (with or without the KV-cache arguments,
    depending on ``with_kv_cache``).
    """

    def arrange_query_or_output(input):
        # Tile rows into blocks of block_size_m; the second tile groups
        # query heads over each KV head (GQA ratio q_heads // kv_heads).
        arranged = input.tile((1, 1, block_size_m, -1)).tile(
            (1, query.shape[-3] // key.shape[-3], 1, 1)
        )
        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_key_or_value(input):
        # Tile the KV sequence into blocks of block_size_n, keep all blocks
        # together, and broadcast across the query-block axis.
        arranged = (
            input.tile((1, 1, block_size_n, -1))
            .tile((1, 1, -1, -1))
            .expand((-1, -1, query_arranged.shape[-2], -1))
        )
        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_total_kv_len(input, shape):
        # Scalar-per-batch length, broadcast to the query grid shape.
        arranged = input.tile((1,))
        arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
        return arranged

    def arrange_present_key_or_present_value(input):
        arranged = input.tile((1, 1, block_size_m, block_size_n))
        arranged.dtype = arranged.dtype.squeeze((0, 1))
        return arranged

    def arrange_attn_mask(input):
        arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    # Fall back to the auto-tuned symbolic block sizes.
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M
    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N

    query_arranged = arrange_query_or_output(query)
    key_arranged = arrange_key_or_value(key)
    value_arranged = arrange_key_or_value(value)
    total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
    present_key_arranged = arrange_present_key_or_present_value(present_key)
    present_value_arranged = arrange_present_key_or_present_value(present_value)
    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
    present_value_slot_arranged = arrange_present_key_or_present_value(present_value_slot)
    attn_mask_arranged = arrange_attn_mask(attn_mask)
    # Scalars/constexprs pass through unchanged.
    is_causal_arranged = is_causal
    scale_arranged = scale
    output_arranged = arrange_query_or_output(output)
    with_attn_mask_arranged = with_attn_mask
    causal_variant_arranged = causal_variant

    if with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            total_kv_len_arranged,
            present_key_arranged,
            present_value_arranged,
            present_key_slot_arranged,
            present_value_slot_arranged,
            attn_mask_arranged,
            is_causal_arranged,
            scale_arranged,
            output_arranged,
            with_attn_mask_arranged,
            causal_variant_arranged,
        )

    return (
        query_arranged,
        key_arranged,
        value_arranged,
        total_kv_len_arranged,
        attn_mask_arranged,
        is_causal_arranged,
        scale_arranged,
        output_arranged,
        with_attn_mask_arranged,
        causal_variant_arranged,
    )
def application_with_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """KV-cache variant: store the present K/V into their cache slots, then
    run the standard attention body.

    NOTE(review): in ninetoothed, assigning to a parameter tensor emits a
    store into that tensor — presumably these two assignments write the cache;
    confirm against the ninetoothed docs.
    """
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841
    application_without_kv_cache(
        query,
        key,
        value,
        total_kv_len,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )
def application_without_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """Flash-attention kernel body (online-softmax formulation).

    Iterates over KV blocks, maintaining a running max, a running
    normalizer (`lse`) and an unnormalized accumulator per query row, so the
    softmax never needs the full score matrix in memory.
    """
    actual_kv_len = total_kv_len[0]
    for i in range(query.shape[0]):
        # Pre-scale Q by scale * log2(e) so exp() can be computed as exp2().
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        # lse's initial value of 1 is irrelevant: on the first block
        # max == -inf, so alpha == 0 and the initial contribution vanishes.
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
        # Ceil-divide the effective KV length by the KV block size.
        for j in range(-(-actual_kv_len // key.dtype.shape[0])):
            qk = ntl.dot(query_i, ntl.trans(key[j]))
            # Mask out positions beyond the effective KV length.
            key_pos = key[j].offsets(-2)
            qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))
            if with_attn_mask:
                qk += attn_mask[j]
            if is_causal:
                query_pos = query[i].offsets(-2)
                if causal_variant == 2:
                    # CausalVariant.LOWER_RIGHT: diagonal is aligned to the
                    # bottom-right corner (query offset by kv_len - q_len).
                    mask = (
                        query_pos[:, None] + actual_kv_len - query.source.shape[-2]
                        >= key_pos[None, :]
                    )
                else:
                    # CausalVariant.UPPER_LEFT: standard causal mask.
                    mask = query_pos[:, None] >= key_pos[None, :]
                qk = ntl.where(mask, qk, float("-inf"))
            # Online-softmax update: rescale previous accumulator by alpha,
            # add the current block's contribution.
            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])
            alpha = ntl.exp2(max - next_max)
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)
        # Final normalization and store.
        acc /= lse[:, None]
        output[i] = acc  # noqa: F841
def premake(
    with_kv_cache,
    emb_dim=None,
    is_causal=None,
    with_attn_mask=None,
    causal_variant=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
):
    """Assemble (arrangement, application, tensors) for one kernel variant.

    Each keyword corresponds to a constexpr axis of the build grid in
    build.py; the returned triple is consumed by the ninetoothed AOT builder.
    """
    arrangement_ = functools.partial(
        arrangement,
        with_kv_cache=with_kv_cache,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
    )

    # Rank-4 tensors; the last (head) dimension is constexpr, capped at 128.
    query, key, value, attn_mask, output = (
        Tensor(
            4,
            dtype=dtype,
            shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
        )
        for _ in range(5)
    )
    total_kv_len = Tensor(1, dtype=ninetoothed.int32)
    present_key, present_value, present_key_slot, present_value_slot = (
        Tensor(4, dtype=dtype) for _ in range(4)
    )
    # Softmax scale is a 0-d double scalar.
    scale = Tensor(0, dtype=ninetoothed.float64)
    # Compile-time flags baked into the kernel variant.
    is_causal = Tensor(0, constexpr=True, value=is_causal)
    with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
    causal_variant = Tensor(0, constexpr=True, value=causal_variant)

    # Pin the head dimension to a concrete constexpr value when requested.
    if emb_dim is not None:
        for tensor in (query, key, value, attn_mask, output):
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    if with_kv_cache:
        application = application_with_kv_cache
    else:
        application = application_without_kv_cache

    # NOTE(review): the KV-cache tensors are always included here even when
    # with_kv_cache is false — presumably `arrangement` drops them; confirm
    # this matches the builder's expectations.
    tensors = (
        query,
        key,
        value,
        total_kv_len,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )

    return arrangement_, application, tensors
src/infiniop/ops/flash_attention/operator.cc
0 → 100644
View file @
6ac8f906
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/flash_attention.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
#include "ninetoothed/descriptor.h"
#endif
#endif
// C ABI dispatch layer: routes each flash-attention API call to the backend
// implementation selected by the descriptor's device type. Currently only the
// ninetoothed backend (NVIDIA, when ENABLE_NINETOOTHED is set) is wired up.

__C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len,
    float scale,
    char is_causal) {
#define CREATE(CASE, NAMESPACE)                                                        \
    case CASE:                                                                         \
        return op::flash_attention::NAMESPACE::Descriptor::create(                     \
            handle,                                                                    \
            reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc,                                                                  \
            q_desc,                                                                    \
            k_desc,                                                                    \
            v_desc,                                                                    \
            total_kv_len,                                                              \
            scale,                                                                     \
            is_causal);

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size) {
#define GET_SIZE(CASE, NAMESPACE)                                                       \
    case CASE:                                                                          \
        *size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
                    ->get_workspace_size();                                             \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}

__C infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    const void *total_kv_len,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE)                                                         \
    case CASE:                                                                             \
        return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                                     \
    case CASE:                                                                       \
        delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
test/infinicore/ops/flash_attention.py
0 → 100644
View file @
6ac8f906
import sys
import os

# Make the shared test framework package (one directory up) importable when
# this file is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import
torch
import
infinicore
from
framework
import
(
BaseOperatorTest
,
TensorSpec
,
TensorInitializer
,
TestCase
,
GenericTestRunner
,
)
# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal)
# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)
# Here shapes are (batch, num_heads, seq_len, head_dim); KV seq_len may exceed
# the query seq_len (decode-style cases).
_TEST_CASES_DATA = [
    ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False),
    ((1, 2, 128, 16), (1, 2, 256, 16), (1, 2, 256, 16), None, 0.0, False),
    ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, True),
    ((1, 8, 256, 16), (1, 8, 512, 16), (1, 8, 512, 16), None, 0.0, True),
    ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False),
    ((8, 28, 256, 128), (8, 28, 512, 128), (8, 28, 512, 128), None, 0.0, True),
]

# Per-dtype numerical tolerances for comparison against the torch reference.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
}

# Dtypes every case is instantiated with.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """Expand _TEST_CASES_DATA x _TENSOR_DTYPES into framework TestCase objects.

    For each case an effective KV length is drawn uniformly from
    [1, kv_seq_len]; the same value is both baked into the total_kv_len input
    tensor (via a degenerate RANDINT range) and passed to the operators as the
    plain-int `cheat` argument so the torch reference can slice K/V.

    NOTE(review): `random` is unseeded, so the chosen lengths differ between
    runs — confirm that non-reproducible cases are acceptable here.
    """
    import random

    cases = []
    for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA:
        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP[dtype]
            q_spec = TensorSpec.from_tensor(q_shape, None, dtype)
            k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
            v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
            # One length per batch element.
            len_shape = (q_shape[0],)
            total_len = random.randint(1, k_shape[2])
            # RANDINT with low == high - 1 yields the constant `total_len`.
            total_kv_len_spec = TensorSpec.from_tensor(
                len_shape,
                None,
                infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=total_len,
                high=total_len + 1,
            )
            kwargs = {
                "attn_mask": attn_mask,
                "dropout_p": dropout_p,
                "is_causal": is_causal,
            }
            # remove None keys
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            cases.append(
                TestCase(
                    inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len],
                    kwargs=kwargs,
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description="Flash Attention",
                )
            )
    return cases
def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
    """Torch reference: SDPA over K/V truncated to the effective length.

    ``total_kv_len`` (the tensor form) is unused here; ``cheat`` carries the
    same length as a plain int so the reference can slice along the KV
    sequence axis.
    """
    trimmed_k, trimmed_v = (t[:, :, :cheat, :] for t in (k, v))
    return torch.nn.functional.scaled_dot_product_attention(
        q, trimmed_k, trimmed_v, **kwargs
    )
def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
    """Adapter to the operator under test.

    `cheat` (the plain-int KV length used by the torch reference) is ignored;
    the kernel reads the length from the `total_kv_len` tensor instead.
    """
    return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs)
class OpTest(BaseOperatorTest):
    """ScaledDotProductAttention operator test with simplified implementation"""

    def __init__(self):
        super().__init__("ScaledDotProductAttention")

    def get_test_cases(self):
        # One TestCase per (shape-config, dtype) combination.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        # Reference path: torch SDPA over sliced K/V.
        return torch_flash_attn(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        # Path under test: infinicore flash_attention.
        return infini_flash_attn(*args, **kwargs)
def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    # Runs all cases and exits with a non-zero status on failure.
    runner.run_and_exit()


if __name__ == "__main__":
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment