Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2033d805
Commit
2033d805
authored
Feb 03, 2026
by
zhanghj2
Browse files
支持纯bf16
parent
58b43d4a
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
867 additions
and
68 deletions
+867
-68
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+1
-1
csrc/sm90/decode/dense/config.h
csrc/sm90/decode/dense/config.h
+1
-1
csrc/sm90/decode/dense/splitkv_mla.cuh
csrc/sm90/decode/dense/splitkv_mla.cuh
+591
-11
csrc/sm90/decode/dense/traits.h
csrc/sm90/decode/dense/traits.h
+95
-54
csrc/utils.h
csrc/utils.h
+177
-0
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+2
-1
No files found.
csrc/api/dense_decode.h
View file @
2033d805
...
@@ -75,7 +75,7 @@ dense_attn_decode_interface(
...
@@ -75,7 +75,7 @@ dense_attn_decode_interface(
const
int
num_heads
=
num_heads_k
;
const
int
num_heads
=
num_heads_k
;
q
=
q
.
view
({
batch_size
,
seqlen_q_ori
,
num_heads_k
,
num_q_heads_per_hk
,
head_size_k
}).
transpose
(
2
,
3
)
q
=
q
.
view
({
batch_size
,
seqlen_q_ori
,
num_heads_k
,
num_q_heads_per_hk
,
head_size_k
}).
transpose
(
2
,
3
)
.
reshape
({
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
});
.
reshape
({
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
});
int
num_sm_parts
=
std
::
max
(
arch
.
num_sms
/
num_heads_k
/
cutlass
::
ceil_div
(
seqlen_q_ori
*
num_heads_q
/
num_heads_k
,
6
4
),
1
);
int
num_sm_parts
=
std
::
max
(
arch
.
num_sms
/
num_heads_k
/
cutlass
::
ceil_div
(
seqlen_q_ori
*
num_heads_q
/
num_heads_k
,
1
6
),
1
);
KU_CHECK_SHAPE
(
q
,
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
);
KU_CHECK_SHAPE
(
q
,
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
);
KU_CHECK_SHAPE
(
kcache
,
num_blocks
,
page_block_size
,
num_heads_k
,
head_size_k
);
KU_CHECK_SHAPE
(
kcache
,
num_blocks
,
page_block_size
,
num_heads_k
,
head_size_k
);
...
...
csrc/sm90/decode/dense/config.h
View file @
2033d805
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
namespace
Config
{
namespace
Config
{
static
constexpr
int
BLOCK_SIZE_M
=
6
4
;
static
constexpr
int
BLOCK_SIZE_M
=
1
6
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
64
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
64
;
static
constexpr
int
HEAD_DIM_K
=
576
;
static
constexpr
int
HEAD_DIM_K
=
576
;
...
...
csrc/sm90/decode/dense/splitkv_mla.cuh
View file @
2033d805
This diff is collapsed.
Click to expand it.
csrc/sm90/decode/dense/traits.h
View file @
2033d805
...
@@ -7,13 +7,12 @@
...
@@ -7,13 +7,12 @@
#include "config.h"
#include "config.h"
using
TMABarrier
=
cutlass
::
arch
::
ClusterTransactionBarrier
;
using
namespace
cute
;
using
namespace
cute
;
template
<
typename
InputT_
>
template
<
typename
InputT_
,
bool
Is_causal_
>
struct
Traits
{
struct
Traits
{
using
InputT
=
InputT_
;
using
InputT
=
InputT_
;
static
constexpr
bool
Is_causal
=
Is_causal_
;
static
constexpr
int
BLOCK_SIZE_M
=
Config
::
BLOCK_SIZE_M
;
static
constexpr
int
BLOCK_SIZE_M
=
Config
::
BLOCK_SIZE_M
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
Config
::
PAGE_BLOCK_SIZE
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
Config
::
PAGE_BLOCK_SIZE
;
static
constexpr
int
HEAD_DIM_K
=
Config
::
HEAD_DIM_K
;
static
constexpr
int
HEAD_DIM_K
=
Config
::
HEAD_DIM_K
;
...
@@ -23,63 +22,105 @@ struct Traits {
...
@@ -23,63 +22,105 @@ struct Traits {
static_assert
(
std
::
is_same_v
<
InputT
,
cutlass
::
bfloat16_t
>
||
std
::
is_same_v
<
InputT
,
cutlass
::
half_t
>
);
static_assert
(
std
::
is_same_v
<
InputT
,
cutlass
::
bfloat16_t
>
||
std
::
is_same_v
<
InputT
,
cutlass
::
half_t
>
);
using
TiledMMA_QK_sQ
=
decltype
(
make_tiled_mma
(
static
constexpr
int
kBlockM
=
BLOCK_SIZE_M
;
GMMA
::
ss_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
K
>
(),
static
constexpr
int
kBlockN
=
PAGE_BLOCK_SIZE
;
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
static
constexpr
int
kHeadDim
=
HEAD_DIM_K
;
));
static
constexpr
int
kHeadDimV
=
HEAD_DIM_V
;
static
constexpr
int
kNWarps
=
4
;
using
TiledMMA_QK_rQ
=
decltype
(
make_tiled_mma
(
GMMA
::
rs_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
K
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
TiledMMA_PV_LocalP
=
decltype
(
make_tiled_mma
(
GMMA
::
rs_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_V
/
2
>
,
Int
<
PAGE_BLOCK_SIZE
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
MN
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
TiledMMA_PV_RemoteP
=
decltype
(
make_tiled_mma
(
GMMA
::
ss_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_V
/
2
>
,
Int
<
PAGE_BLOCK_SIZE
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
MN
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_K
>>
{}
));
using
Element
=
InputT
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
using
SmemLayoutRow
=
Layout
<
Shape
<
_128
>
,
Stride
<
_1
>>
;
using
SmemLayoutAtomK
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
{}));
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
SmemLayoutAtomK
{},
Shape
<
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
{}
Shape
<
Int
<
kBlockN
>
,
Int
<
16
*
32
>>
{}));
));
using
SmemLayoutK_place_holder
=
decltype
(
tile_to_shape
(
using
SmemLayoutV
=
decltype
(
composition
(
SmemLayoutAtomK
{},
SmemLayoutK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
15
*
32
>>
{}));
make_layout
(
Shape
<
Int
<
HEAD_DIM_V
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{},
GenRowMajor
{})
using
SmemLayoutAtomV
=
SmemLayoutAtomK
;
));
// A transposed version of SmemLayoutK
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{}));
using
SmemLayoutAtomP
=
Layout
<
Shape
<
Int
<
4
*
16
*
16
>>
,
Stride
<
Int
<
1
>>>
;
using
SmemLayoutP
=
decltype
(
tile_to_shape
(
SmemLayoutAtomP
{},
Shape
<
Int
<
4
*
16
*
16
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
kHeadDimV
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
using
SmemLayoutAtomQ
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
64
>>
,
Stride
<
Int
<
64
>
,
_1
>>
{}));
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{}));
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
// #if defined(__gfx936__) || defined(__gfx938__)
using
MMA_Atom_Arch
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NT
>
>
;
using
TiledMma
=
TiledMMA
<
MMA_Atom_Arch
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
// #elif defined(__gfx928__)
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
// >;
// using TiledMma = TiledMMA<
// MMA_Atom_Arch,
// Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
// ValLayoutMNK>;
// #endif
using
MMA_Atom_Arch_16x32
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
TiledMma_O
=
TiledMMA
<
MMA_Atom_Arch_16x32
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
GmemLayoutAtomQ
=
Layout
<
Shape
<
_32
,
_8
>
,
Stride
<
_8
,
_1
>>
;
using
GmemTiledCopyQ
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtomQ
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
SmemLayoutP0
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{}
));
using
rP0Layout
=
decltype
(
layout
(
partition_fragment_C
(
TiledMMA_QK_sQ
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{}
)));
struct
SharedMemoryPlan
{
struct
SharedMemoryPlan
{
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutQ
>>
smem_sQ
;
union
{
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutK
>>
smem_sK0
;
struct
{
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutK
>>
smem_sK1
;
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutV
>>
smem_v
;
// Double buffer
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutP0
>>
smem_sP0
;
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sM
;
};
cute
::
array_aligned
<
float
,
2
*
BLOCK_SIZE_M
>
sL_reduction_wksp
;
struct
{
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sScale0
;
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutK_place_holder
>>
smem_temp
;
// Double buffer
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sScale1
;
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
TMABarrier
barriers_K0
[
HEAD_DIM_K
/
64
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_sum
;
TMABarrier
barriers_K1
[
HEAD_DIM_K
/
64
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
TMABarrier
barrier_Q
;
};
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutQ
>>
smem_q
;
};
};
};
};
};
};
...
...
csrc/utils.h
View file @
2033d805
...
@@ -88,6 +88,18 @@ struct RingBufferState {
...
@@ -88,6 +88,18 @@ struct RingBufferState {
}
}
};
};
// BOOL_SWITCH(COND, CONST_NAME, callable)
//
// Turns the runtime condition COND into a compile-time boolean named
// CONST_NAME and invokes the trailing callable (normally a lambda) with that
// constant in scope. The callable is expanded once per value, so it can use
// CONST_NAME as a template argument or in `if constexpr`.
//
// The whole expression is an immediately-invoked lambda capturing by
// reference; its value is whatever the callable returns. Both expansions
// must therefore return the same type.
#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&]() {                                             \
        if (COND) {                                     \
            static constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        }                                               \
        static constexpr bool CONST_NAME = false;       \
        return __VA_ARGS__();                           \
    }()
namespace
flash
{
namespace
flash
{
using
namespace
cute
;
using
namespace
cute
;
...
@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla(
...
@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla(
}
}
// Loads one 128-bit chunk per thread from the memory behind `src` into `dst`
// using an AMDGPU buffer load.
//
// Template parameters:
//   Is_even_MN  when false, rows with index >= max_MN get their byte offset
//               forced to -1 (presumably so the hardware range check rejects
//               the access -- confirm against the ISA documentation).
//   Is_even_K   not referenced by this function.
//   mma_layout  selects the thread -> (row, column) mapping, see below.
//   use_asm     true: issue the load through inline assembly;
//               false: use __builtin_amdgcn_buffer_load_dwordx4.
//
// Parameters:
//   src         tensor whose data pointer becomes the buffer base address.
//   dst         receives the 128 bits that were loaded.
//   k_idx_      column-block index; each block advances the column by 32
//               elements. Broadcast through readfirstlane before use.
//   row_stride  distance between consecutive rows, in elements.
//   offset_k    not referenced by this function.
//   max_MN      row bound used only when Is_even_MN is false.
template<bool Is_even_MN=true, bool Is_even_K=true, bool mma_layout=false, bool use_asm=false, class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void buffer_load_copy(Tensor<SrcEngine, SrcLayout> const &src, uint128_t &dst, int k_idx_, const int row_stride, int offset_k, const int max_MN=0) {
    constexpr int kWaveSize = 64;
    constexpr int kBytesPerElem = 2;
    constexpr int kElemsPerLane = 8;

    const int thread_id = threadIdx.x;
    const int wave_id = __builtin_amdgcn_readfirstlane(thread_id / kWaveSize);
    const int lane_id = thread_id % kWaveSize;
    const int k_block = __builtin_amdgcn_readfirstlane(k_idx_);

    // Split the 64-bit base address into the two low descriptor words.
    const uint64_t base_addr = reinterpret_cast<uint64_t>(src.data().get());
    const uint32_t addr_lo = static_cast<uint32_t>(base_addr);
    const uint32_t addr_hi = static_cast<uint32_t>(base_addr >> 32);

    uint32x4_t rsrc = {0};
    int tile_row = 0;
    int lane_col = 0;
    if constexpr (mma_layout) {
        // Address words are broadcast from the first active lane.
        rsrc[0] = __builtin_amdgcn_readfirstlane(addr_lo);
        rsrc[1] = __builtin_amdgcn_readfirstlane(addr_hi);
        // 16 rows per wave; each group of 16 lanes handles one 8-element column chunk.
        tile_row = thread_id % 16 + wave_id * 16;
        lane_col = lane_id / 16;
    } else {
        rsrc[0] = addr_lo;
        rsrc[1] = addr_hi;
        // Four consecutive threads share a row, each taking one 8-element column chunk.
        tile_row = thread_id / 4;
        lane_col = lane_id % 4;
    }
    // NOTE(review): words 2 and 3 look like the record-count and format/flags
    // words of the buffer resource descriptor -- confirm against the ISA docs.
    rsrc[2] = 0x80000000;
    rsrc[3] = 0x00020000;

    const int tile_col = lane_col * kElemsPerLane + k_block * 32;
    int byte_offset = (tile_row * row_stride + tile_col) * kBytesPerElem;
    if (!Is_even_MN && tile_row >= max_MN) {
        byte_offset = -1;
    }

    if constexpr (use_asm) {
        asm volatile("buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0\n"
                     "\n\t"
                     : "=v"(dst), "+v"(byte_offset), "+s"(rsrc));
    } else {
        auto loaded = __builtin_amdgcn_buffer_load_dwordx4(rsrc, 0, byte_offset, false, false);
        dst = *reinterpret_cast<uint128_t *>(&loaded);
    }
}
// Writes the 128-bit value `src` into `dst`, starting at the address of
// element (0, 0, k_idx).
//
// The destination element's address is reinterpreted as a uint128_t slot, so
// this assumes at least 16 bytes of suitably aligned storage begin there --
// TODO confirm against the fragment layouts used by callers.
template<class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void buffer_to_tensor(const uint128_t &src, Tensor<SrcEngine, SrcLayout> &dst, int k_idx) {
    *reinterpret_cast<uint128_t *>(&dst(0, 0, k_idx)) = src;
}
// Re-shuffles per-thread values held in `tOrP` into an A-operand fragment for
// `tiled_mma_o`, using `sAcc` as a staging buffer.
//
// Steps:
//   1. Each thread stores tOrP(0..3, 0, 0) into sAcc at indices derived from
//      its lane and wave number.
//   2. __syncthreads() makes the staged values visible to the whole block.
//   3. A 16x64 row-major view over sAcc's data is used only to size the
//      fragment returned by partition_fragment_A; the 16 fragment values are
//      then read back from sAcc by linear index.
//
// Parameters:
//   tiled_mma    not referenced by this function.
//   tiled_mma_o  MMA whose thread slice defines the returned fragment.
//   tOrP         source values; only elements (0..3, 0, 0) are read.
//   sAcc         staging tensor, indexed linearly. Presumably backed by
//                shared memory (it is written, then read after
//                __syncthreads()) -- confirm at the call site.
//
// Returns the fragment produced by partition_fragment_A, filled with 16
// values taken from sAcc.
template<class TiledMma, class TiledMma_O, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto convert_layout_acc_Aregs_dense(
    const TiledMma &tiled_mma, const TiledMma_O &tiled_mma_o,
    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> const &sAcc) {
    using Value_type = typename Engine0::value_type;

    constexpr int kWaveSize = 64;
    constexpr int kHalfStride = 16 * 32;  // offset between the two staged halves

    const int lane_id = threadIdx.x % kWaveSize;
    const int wave_id = threadIdx.x / kWaveSize;

    // Stage the four source values; consecutive values sit 16 * 8 slots apart.
    const int store_base = (lane_id % 16) * 8 + (lane_id / 16)
                         + (wave_id % 2) * 4 + (wave_id / 2) * kHalfStride;
    #pragma unroll
    for (int v = 0; v < 4; ++v) {
        sAcc(store_base + v * 16 * 8) = tOrP(v, 0, 0);
    }
    __syncthreads();

    using StagingAtom = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
    using StagingLayout = decltype(tile_to_shape(StagingAtom{}, Shape<Int<16>, Int<64>>{}));
    Tensor staged = make_tensor(sAcc.data(), StagingLayout{});

    auto thr_mma = tiled_mma_o.get_thread_slice(lane_id);
    Tensor frag = thr_mma.partition_fragment_A(staged);

    // Read back 4 values for each of the 4 k-slices: odd slices are shifted
    // by 4 slots, slices 2 and 3 come from the second staged half.
    const int load_base = lane_id * 8;
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int k_shift = (k % 2) * 4 + (k / 2) * kHalfStride;
        #pragma unroll
        for (int v = 0; v < 4; ++v) {
            frag(v, 0, k) = sAcc(load_base + v + k_shift);
        }
    }
    return frag;
}
}
}
\ No newline at end of file
tests/test_flash_mla_dense_decoding.py
View file @
2033d805
...
@@ -223,9 +223,10 @@ def main(torch_dtype):
...
@@ -223,9 +223,10 @@ def main(torch_dtype):
]
]
performance_cases
=
[
performance_cases
=
[
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
test_performance
=
True
)
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
h_q
=
h_q
,
test_performance
=
True
)
for
is_causal
in
[
False
,
True
]
for
is_causal
in
[
False
,
True
]
for
s_q
in
[
1
,
2
]
for
s_q
in
[
1
,
2
]
for
h_q
in
[
16
,
128
]
for
s_k
in
[
4096
,
8192
,
16384
,
32768
]
for
s_k
in
[
4096
,
8192
,
16384
,
32768
]
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment