Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
119dd2ac
Commit
119dd2ac
authored
Jan 31, 2025
by
Qianfeng Zhang
Browse files
Let V reuse the LDS of K
parent
512eeecb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
175 deletions
+12
-175
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+8
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+4
-165
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
119dd2ac
...
@@ -175,12 +175,11 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -175,12 +175,11 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert
(
NumKLdsBuffers
>=
2
);
static_assert
(
NumKLdsBuffers
>=
2
);
static_assert
(
NumVLdsBuffers
>=
2
);
static_assert
(
NumVLdsBuffers
>=
2
);
auto
q_dram_window
=
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
auto
q
=
load_tile
(
q_dram_window
);
auto
q
=
load_tile
(
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -193,8 +192,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -193,8 +192,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// V tile in LDS
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
GetSmemSizeK
<
Problem
>()),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
...
@@ -296,7 +294,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -296,7 +294,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{
{
if
(
i_total_loops
==
0
)
// executed by fist iteration
if
(
i_total_loops
==
0
)
// executed by fist iteration
{
{
if
(
i
_total_loop
s
<
num_total_loop
)
if
(
num
_total_loop
>
1
)
// there are multiple iterations
{
{
auto
k_lds_window_tmp
=
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
...
@@ -341,7 +339,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -341,7 +339,7 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window_tmp
);
k_lds_window_tmp
);
}
}
else
else
// there is only single iteration
{
{
auto
k_lds_window_tmp
=
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
119dd2ac
...
@@ -13,174 +13,13 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -13,174 +13,13 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */
true
,
/* AsyncCopy = */
true
,
/* NumPrefetchV = */
2
>
/* NumPrefetchV = */
2
>
{
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
// this should align with MakeQDramTileDistribution()
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
return
min
(
ElemPerThread
,
MaxVectorSize
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
MaxVectorSize
=
16
/
sizeof
(
typename
Problem
::
QDataType
);
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kMaxVecLoad
=
min
(
ElemPerThread
,
MaxVectorSize
);
constexpr
index_t
KPerThread
=
kMaxVecLoad
;
constexpr
index_t
KThreads
=
kKPerBlock
/
KPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
KThreads
;
constexpr
index_t
NumWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
MThreadPerWarp
*
NumWarps
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
NumWarps
,
MThreadPerWarp
>
,
sequence
<
KThreads
,
KPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
Problem::BlockFmhaShape::kSubQKHeaddim)
? Problem::BlockFmhaShape::kSubQKHeaddim
: Problem::BlockFmhaShape::kK0;
using GemmProblem = BlockGemmProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
BlockGemmK>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename
Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); static_assert(WarpGemmM == 4 ||
WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename
Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
*/
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackQ
()
{
// TODO: this is for 3d layout
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
QDataType
));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kSubQKHeaddim
;
constexpr
index_t
ElemPerThread
=
(
kMPerBlock
*
kKPerBlock
)
/
kBlockSize
;
static_assert
(
0
<
ElemPerThread
);
constexpr
index_t
kKPack
=
min
(
ElemPerThread
,
GetSmemKPackQ
<
Problem
>
());
constexpr
auto
q_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kMPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
q_lds_block_desc
=
transform_tensor_descriptor
(
q_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kMPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
q_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
return
MakeQLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
()
*
sizeof
(
typename
Problem
::
QDataType
);
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
// assume Q can reuse the shared memory with K or V
// assume V can reuse the shared memory by K
// assume Dropout can reuse the shared memory with V
// assume Dropout can reuse the shared memory by V
return
max
(
GetSmemSizeQ
<
Problem
>
(),
return
max
(
GetSmemSizeK
<
Problem
>
(),
GetSmemSizeK
<
Problem
>
()
+
max
(
GetSmemSizeV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
(
0
)));
max
(
GetSmemSizeV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
(
0
)));
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment