Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4cc514f8
Commit
4cc514f8
authored
Aug 07, 2024
by
danyao12
Browse files
fix unpadded lse issue in fwd splitkv
parent
15758862
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
30 deletions
+21
-30
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+14
-15
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+7
-15
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
4cc514f8
...
...
@@ -99,10 +99,9 @@ struct FmhaFwdSplitKVCombineKernel
struct
CommonLSEKargs
{
void
*
lse_ptr
=
nullptr
;
ck_tile
::
index_t
nhead_stride_lse
=
0
;
ck_tile
::
index_t
batch_stride_lse_acc
=
0
;
ck_tile
::
index_t
batch_stride_lse
=
0
;
void
*
lse_ptr
=
nullptr
;
ck_tile
::
index_t
nhead_stride_lse
=
0
;
ck_tile
::
index_t
batch_stride_lse
=
0
;
};
struct
Fp8StaticQuantKargs
...
...
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
};
struct
GroupModeKargs
...
...
@@ -171,14 +171,14 @@ struct FmhaFwdSplitKVCombineKernel
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
batch_stride_o
};
batch_stride_o
,
batch_stride_lse_acc
};
if
constexpr
(
kStoreLSE
)
{
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
batch_stride_lse_acc
=
batch_stride_lse_acc
;
kargs
.
batch_stride_lse
=
batch_stride_lse
;
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
batch_stride_lse
=
batch_stride_lse
;
}
if
constexpr
(
kDoFp8StaticQuant
)
{
...
...
@@ -282,12 +282,12 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse_acc
=
query_start
;
batch_offset_lse
=
query_start
;
batch_offset_lse
=
query_start
;
}
// get real # queries & # keys under group mode
...
...
@@ -303,12 +303,11 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
4cc514f8
...
...
@@ -47,7 +47,6 @@ struct FmhaFwdSplitKVKernel
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
static
constexpr
bool
kStoreLSE
=
FmhaPipeline
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
...
...
@@ -520,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
...
@@ -538,10 +538,6 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse_acc
=
query_start
;
}
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
@@ -566,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
...
...
@@ -578,11 +575,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
}
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
}
}
// for simplicity, batch stride we just modify the pointer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment