Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
00fe0752
Commit
00fe0752
authored
Jan 24, 2025
by
Qianfeng Zhang
Browse files
Use LDS as intermediary stop when loading Q from global memory for qr_ks_vs_async pipeline
parent
34157f26
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
134 additions
and
10 deletions
+134
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+39
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+95
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
00fe0752
...
...
@@ -164,9 +164,21 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
auto
original_q
=
load_tile
(
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
// Q tile in LDS
QDataType
*
q_lds_ptr
=
static_cast
<
QDataType
*>
(
smem_ptr
);
auto
q_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
...
...
@@ -184,14 +196,6 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
auto
q
=
load_tile
(
q_dram_window
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
...
...
@@ -256,6 +260,8 @@ struct BlockFmhaPipelineQRKSVSAsync
// load
auto
k_tile
=
load_tile
(
k_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -272,8 +278,27 @@ struct BlockFmhaPipelineQRKSVSAsync
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// store Q into LDS
__builtin_amdgcn_sched_barrier
(
0
);
auto
q_lds_window_for_store
=
make_tile_window
(
q_lds
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
store_tile
(
q_lds_window_for_store
,
original_q
);
__builtin_amdgcn_sched_barrier
(
0
);
// load Q from LDS
auto
q_lds_window_for_load
=
make_tile_window
(
q_lds
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
},
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
block_sync_lds
();
auto
q
=
load_tile
(
q_lds_window_for_load
);
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
__builtin_amdgcn_sched_barrier
(
0
);
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
...
...
@@ -281,6 +306,10 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// ensure loading of Q from LDS completely done
block_sync_lds
();
do
{
store_tile
(
k_lds_window
,
k_tile
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
00fe0752
...
...
@@ -13,6 +13,50 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */
true
,
/* NumPrefetchV = */
2
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
// this should align with MakeQDramTileDistribution()
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
return
min
(
ElemPerThread
,
MaxVectorSize
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kMaxVecLoad
=
min
(
ElemPerThread
,
MaxVectorSize
);
constexpr
index_t
KPerThread
=
kMaxVecLoad
;
constexpr
index_t
KThreads
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
KThreads
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
MThreadPerWarp
*
NumWarps
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
NumWarps
,
MThreadPerWarp
>
,
sequence
<
KThreads
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
...
...
@@ -82,6 +126,57 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
else
return
BlockGemmARegBSmemCRegOneWarpV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
QDataType
));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kKPack
=
min
(
ElemPerThread
,
GetSmemKPackQ
<
Problem
>
());
constexpr
auto
q_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kMPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
q_lds_block_desc
=
transform_tensor_descriptor
(
q_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kMPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
q_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
QDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
// assume Q can reuse the shared memory with K or V
return
max
(
GetSmemSizeQ
<
Problem
>
(),
GetSmemSizeK
<
Problem
>
()
+
GetSmemSizeV
<
Problem
>
())
+
GetSmemSizeDropout
<
Problem
>
(
0
);
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment