Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
38eb3bba
"...composable_kernel_rocm.git" did not exist on "1da802bdf2cc56defeaacf73c63461fe3c4cf692"
Commit
38eb3bba
authored
Nov 17, 2023
by
letaoqin
Browse files
Q load once when Q_k <= 64
parent
2f93e26f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
41 deletions
+43
-41
example/52_flash_atten_bias/batched_gemm_multihead_attention_infer.cpp
...ash_atten_bias/batched_gemm_multihead_attention_infer.cpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+42
-40
No files found.
example/52_flash_atten_bias/batched_gemm_multihead_attention_infer.cpp
View file @
38eb3bba
...
...
@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|-------------------------------------|
Gemm1
*/
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
64
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
View file @
38eb3bba
...
...
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp"
namespace
ck
{
...
...
@@ -95,14 +96,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
Q_d
=
64
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
...
...
@@ -232,12 +234,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_
size_aligned
)
*
sizeof
(
FloatGemmAcc
);
const
index_t
softmax_bytes_end
=
SharedMemTrait
::
reduction_space_
offset
*
sizeof
(
FloatAB
)
+
SharedMemTrait
::
reduction_space_size_aligned
*
sizeof
(
FloatGemmAcc
);
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
}
...
...
@@ -498,7 +499,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static
constexpr
auto
max_lds_align
=
math
::
lcm
(
math
::
lcm
(
AK1
,
BK1
),
B1K1
);
static
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
()
*
Number
<
Q_d
/
KPerBlock
>
{},
max_lds_align
);
static
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b1_block_space_size_aligned
=
math
::
integer_least_multiple
(
...
...
@@ -506,13 +508,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static
constexpr
auto
a_block_space_offset
=
0
;
static
constexpr
auto
b_block_space_offset
=
a_block_space_size_aligned
.
value
;
static
constexpr
auto
b1_block_space_offset
=
0
;
static
constexpr
auto
b1_block_space_offset
=
a_block_space_size_aligned
.
value
;
// LDS allocation for reduction
static
constexpr
index_t
reduction_space_size_aligned
=
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
static
constexpr
auto
reduction_space_offset
=
0
;
static
constexpr
auto
reduction_space_offset
=
a_block_space_size_aligned
.
value
;
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -611,7 +613,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
...
...
@@ -673,11 +674,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto
acc_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
// LDS allocation for B: be careful of alignment
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
...
...
@@ -689,11 +686,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const
auto
b_block_reset_copy_step
=
make_multi_index
(
-
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
NPerBlock
,
0
);
const
auto
Q_k
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopScheduler
::
Default
>
();
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1r1
<
NumGemmKPrefetchStage
>
{};
/*GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();*/
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
...
...
@@ -843,7 +842,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// Blockwise softmax
//
auto
workspace_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
static_cast
<
FloatGemmAcc
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
*
sizeof
(
FloatAB
)
/
sizeof
(
FloatGemmAcc
),
SharedMemTrait
::
reduction_space_size_aligned
);
// get acc0 8D thread cluster
...
...
@@ -940,21 +940,23 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue
;
}
// gemm0
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
acc_thread_buf
,
num_k_block_main_loop
);
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
FloatAB
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
acc_thread_buf
,
num_k_block_main_loop
,
p_shared
,
gemm1_k_block_outer_index
==
0
&&
Q_k
<=
64
);
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment