Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
de6dd79f
Commit
de6dd79f
authored
Dec 19, 2024
by
Po Yen Chen
Browse files
Fix compilation errors
parent
232864b4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
164 additions
and
118 deletions
+164
-118
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+164
-118
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
View file @
de6dd79f
...
@@ -16,20 +16,19 @@ namespace ck_tile {
...
@@ -16,20 +16,19 @@ namespace ck_tile {
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineQRKSVSAsync
struct
BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
RandValOutputDataType
=
remove_cvref_t
<
typename
Problem
::
RandValOutputDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -51,13 +50,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -51,13 +50,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
// only seq_k padding needs special care (out-of-bound elements of p must be set to -INF instead of zero)
// only seq_k padding needs special care (out-of-bound elements of p must be set to -INF instead of zero)
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
Problem
::
kPadHeadDimV
==
true
);
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
bool
kIsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
// ... together with tensor distribution. tensor dist should be able to overwrite this
...
@@ -69,7 +69,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -69,7 +69,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
else
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
}();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentOacc
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOacc
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
@@ -83,7 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -83,7 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
else
else
{
{
// minimize occupancy
// minimize occupancy
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
)
{
{
return
1
;
return
1
;
}
}
...
@@ -119,24 +121,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -119,24 +121,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static
constexpr
const
char
*
name
=
"qr_async"
;
static
constexpr
const
char
*
name
=
"qr_async"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
QDramBlockWindowTmp
,
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
VDramBlockWindowTmp
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
LSE
acc
ElementFunction
,
typename
SAccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
OAccElementFunction
,
...
@@ -144,35 +145,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -144,35 +145,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KElementFunction
&
/*k_element_func*/
,
const
KElementFunction
&
/*k_element_func*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VElementFunction
&
v_element_func
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
DropoutType
&
dropout
)
const
void
*
smem_ptr
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
K
Dram
Block
WindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
K
Page
Block
Navigator
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
V
Dram
Block
WindowTmp
::
DataType
>>
,
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
V
Page
Block
Navigator
::
DataType
>>
,
"wrong!"
);
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindow
L
engths
{}
[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
1
>
{}]
&&
kK0
==
KDramBlockWindow
L
engths
{}
[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
0
>
{}]
&&
kN1
==
VDramBlockWindow
L
engths
{}
[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindow
Tmp
{}.
get_window_l
engths
()
[
number
<
1
>
{}]
&&
kK1
==
VDramBlockWindow
L
engths
{}
[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
"wrong!"
);
...
@@ -264,24 +268,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -264,24 +268,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
const
auto
[
logical_seqlen_k_start
,
logical_seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if no work to do
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
{
if
(
num_total_loop
<=
0
)
const
index_t
logical_num_total_loop
=
integer_divide_ceil
(
logical_seqlen_k_end
-
logical_seqlen_k_start
,
kN0
);
if
(
logical_num_total_loop
<=
0
)
{
{
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
auto
lse
=
auto
lse
_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
set_tile
(
lse
_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// otherwise will have compute error(maybe compiler bug?)
...
@@ -292,23 +297,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -292,23 +297,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
}
}
auto
k_dram_block_window
=
const
index_t
physical_seqlen_k_start
=
logical_seqlen_k_start
+
kv_l2p_offset
;
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
const
index_t
physical_seqlen_k_end
=
logical_seqlen_k_end
+
kv_l2p_offset
;
k_dram_block_window_tmp
.
get_window_lengths
(),
// make sure the first tile is completely located in page-block (page-block size should be
{
seqlen_k_start
,
0
});
// divisible by kN0)
// relationship between the *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, and logical_seqlen_k_start <= physical_seqlen_k_start
const
index_t
aligned_physical_seqlen_k_start
=
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
]
{
if
constexpr
(
kIsPagedKV
)
{
return
kN0
*
integer_divide_floor
(
physical_seqlen_k_start_
,
kN0
);
}
else
{
return
physical_seqlen_k_start_
;
}
}();
const
index_t
num_total_loop
=
integer_divide_ceil
(
physical_seqlen_k_end
-
aligned_physical_seqlen_k_start
,
kN0
);
auto
[
i_page_block_k
,
k_dram_block_window
]
=
k_page_block_navigator
.
make_tile_window
(
k_dram_block_window_lengths
,
{
aligned_physical_seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
,
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
// load
k_dram_window
.
init_raw
();
k_dram_window
.
init_raw
();
constexpr
auto
k_oob_ck
=
bool_constant
<
true
>
{};
constexpr
auto
k_oob_ck
=
bool_constant
<
true
>
{};
constexpr
auto
k_pre_np
=
[
&
]()
{
constexpr
auto
k_pre_np
=
[
&
]()
{
if
constexpr
(
kPadSeqLenK
&&
if
constexpr
(
kPadSeqLenK
&&
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
)))
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)))
return
bool_constant
<
true
>
{};
return
bool_constant
<
true
>
{};
else
else
return
bool_constant
<
false
>
{};
return
bool_constant
<
false
>
{};
...
@@ -318,17 +338,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -318,17 +338,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
auto
bias_dram_window
=
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
logical_seqlen_k_start
-
(
physical_seqlen_k_start
-
aligned_physical_seqlen_k_start
)},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
v_dram_block_window_lengths
,
{
0
,
aligned_physical_seqlen_k_start
},
// TODO: hdim split?
auto
v_dram_window
=
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
// prefetch K tile
async_load_tile_raw
(
async_load_tile_raw
(
...
@@ -438,7 +456,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -438,7 +456,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
// position_encoding accepts only logical coordinates, so do the conversion here
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
-
kv_l2p_offset
);
});
});
});
});
}
}
...
@@ -450,9 +469,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -450,9 +469,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
#endif
#endif
}
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in first/last iteration without increasing code size
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
if
constexpr
(
kIsPagedKV
)
{
return
col
<
physical_seqlen_k_start_
||
physical_seqlen_k_end_
<=
col
;
}
else
{
return
physical_seqlen_k_end_
<=
col
;
}
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
// mask accepts only logical coordinates, so do the conversion here
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kM0
>
{},
...
@@ -464,7 +509,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -464,7 +509,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
return
mask
.
IsOutOfBound
(
row
,
col
-
kv_l2p_offset
);
});
});
}
}
}
}
...
@@ -513,9 +558,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -513,9 +558,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if
constexpr
(
k1_loops
>
1
)
if
constexpr
(
k1_loops
>
1
)
{
{
move_tile_window
(
i_page_block_v
=
v_page_block_navigator
.
move_tile_window
(
v_dram_window
,
i_page_block_v
,
v_dram_window
,
{
0
,
kK1
});
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_buf
=
load_tile
(
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
}
}
...
@@ -595,17 +639,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -595,17 +639,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
});
});
if
constexpr
(
kHasDropout
)
{
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
const
auto
p
=
[
&
]()
{
const
auto
p
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
PDataType
,
fp16_t
>
)
if
constexpr
(
std
::
is_same_v
<
PDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
PDataType
>
(
return
impl
::
cast_tile_pk_fp16_fp32
<
PDataType
>
(
...
@@ -618,11 +651,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -618,11 +651,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
// STAGE 3, KV gemm
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
if
constexpr
(
k1_loops
>
1
)
{
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
,
&
i_page_block_v_
=
i_page_block_v
,
&
v_dram_window_
=
v_dram_window
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
{
v_buf
=
load_tile
(
v_buf
=
load_tile
(
v_dram_window_
,
v_dram_window
,
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
number
<-
1
>
{},
bool_constant
<
false
>
{});
// load next v_buf
}
}
block_sync_lds
();
block_sync_lds
();
gemm_1
(
o_acc
,
gemm_1
(
o_acc
,
...
@@ -656,14 +692,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -656,14 +692,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
}
}
if
constexpr
(
i_k1
<
k1_loops
-
1
)
if
constexpr
(
i_k1
<
k1_loops
-
1
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
i_page_block_v_
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v_
,
v_dram_window_
,
{
0
,
kK1
});
});
});
}
}
i_total_loops
++
;
i_total_loops
++
;
if
(
i_total_loops
<
num_total_loop
)
if
(
i_total_loops
<
num_total_loop
)
{
{
// move K tile windows
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
i_page_block_k
=
k_page_block_navigator
.
move_tile_window
(
i_page_block_k
,
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
.
set_window_origin
(
k_dram_block_window
.
get_window_origin
());
k_dram_window
.
set_window_origin
(
k_dram_block_window
.
get_window_origin
());
if
constexpr
(
k1_loops
>=
2
&&
if
constexpr
(
k1_loops
>=
2
&&
...
@@ -689,30 +728,30 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -689,30 +728,30 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}
}
}
while
(
i_total_loops
<
num_total_loop
);
}
while
(
i_total_loops
<
num_total_loop
);
// store lse
// store lse
acc
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
auto
lse
_acc
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_spans
=
decltype
(
lse
)
::
get_distributed_spans
();
constexpr
auto
lse_
acc_
spans
=
decltype
(
lse
_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
sweep_tile_span
(
lse_
acc_
spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
lse
_acc
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
}
else
else
{
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
lse
_acc
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
}
#else
#else
lse
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
lse
_acc
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
#endif
});
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
store_tile
(
lse_
acc_
dram_window_tmp
,
tile_elementwise_in
(
lse_
acc_
element_func
,
lse
_acc
));
}
}
// finally, O
// finally, O
...
@@ -740,44 +779,51 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -740,44 +779,51 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}
}
template
<
typename
QDramBlockWindowTmp
,
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
VDramBlockWindowTmp
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEaccDramBlockWindowTmp
&
lse_acc_dram_block_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
DropoutType
&
dropout
)
const
void
*
smem_ptr
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
k_dram_block_window_tmp
,
k_dram_block_window_lengths
,
k_page_block_navigator
,
identity
{},
identity
{},
v_dram_block_window_tmp
,
v_dram_block_window_lengths
,
v_page_block_navigator
,
identity
{},
identity
{},
bias_dram_block_window_tmp
,
bias_dram_block_window_tmp
,
identity
{},
identity
{},
randval_dram_block_window_tmp
,
lse_acc_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
identity
{},
identity
{},
identity
{},
identity
{},
num_splits
,
i_split
,
mask
,
mask
,
position_encoding
,
position_encoding
,
scale_s
,
scale_s
,
smem_ptr
,
kv_l2p_offset
,
dropout
);
smem_ptr
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment