Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a94ac4bb
Commit
a94ac4bb
authored
Feb 02, 2025
by
Qianfeng Zhang
Browse files
Use QLoadOnce == false for qr_ks_vs_async pipeline
parent
475c0d2c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
11 deletions
+68
-11
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+0
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+1
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+67
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+0
-3
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
a94ac4bb
...
...
@@ -35,9 +35,6 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
a94ac4bb
...
...
@@ -31,12 +31,9 @@ struct BlockFmhaPipelineQRKSVSAsync
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static
constexpr
bool
kQLoadOnce
=
false
;
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
true
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
a94ac4bb
...
...
@@ -9,7 +9,7 @@
namespace
ck_tile
{
struct
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
tru
e
,
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
fals
e
,
/* AsyncCopy = */
true
,
/* NumPrefetchV = */
2
>
{
...
...
@@ -22,6 +22,72 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
Problem
::
BlockFmhaShape
::
kK0
>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
    // Short-hands for the Problem members used repeatedly below.
    using QDataT    = typename Problem::QDataType;
    using KDataT    = typename Problem::KDataType;
    using SaccDataT = typename Problem::SaccDataType;
    using FmhaShape = typename Problem::BlockFmhaShape;

    // Block tile extents come from kM0 / kN0 / kK0 of the FMHA shape.
    using Gemm0TileShape = TileGemmShape<sequence<FmhaShape::kM0, FmhaShape::kN0, FmhaShape::kK0>,
                                         typename FmhaShape::Gemm0BlockWarps,
                                         typename FmhaShape::Gemm0WarpTile>;

    // The 4th argument is kNumGemm0Warps * get_warp_size()
    // (presumably the number of threads taking part in gemm0 -- confirm against
    // BlockGemmProblem).
    using GemmProblem = BlockGemmProblem<QDataT,
                                         KDataT,
                                         SaccDataT,
                                         Problem::kNumGemm0Warps * get_warp_size(),
                                         Gemm0TileShape>;

    // Choose the warp-level GEMM from the (Q, K, Sacc) data types and from the
    // first entry of Gemm0WarpTile. The lambda is invoked immediately; only the
    // type of its result is used afterwards (through decltype).
    constexpr auto qk_warp_gemm = []() {
        constexpr index_t WarpGemmM = FmhaShape::Gemm0WarpTile::at(number<0>{});
        static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);

        // All supported combinations accumulate in float.
        constexpr bool kAccIsFloat = std::is_same_v<SaccDataT, float>;

        constexpr bool kIsFp16 =
            std::is_same_v<QDataT, half_t> && std::is_same_v<KDataT, half_t> && kAccIsFloat;
        constexpr bool kIsBf16 =
            std::is_same_v<QDataT, bf16_t> && std::is_same_v<KDataT, bf16_t> && kAccIsFloat;
        constexpr bool kIsFp8 =
            std::is_same_v<QDataT, fp8_t> && std::is_same_v<KDataT, fp8_t> && kAccIsFloat;

        if constexpr(kIsFp16)
        {
            if constexpr(WarpGemmM == 32)
            {
                return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
            }
            else if constexpr(WarpGemmM == 16)
            {
                return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
            }
            else
            {
                // Remaining case allowed by the static_assert above: WarpGemmM == 4
                return WarpGemmMfmaF16F16F32M4N64K16{};
            }
        }
        else if constexpr(kIsBf16)
        {
            if constexpr(WarpGemmM == 32)
            {
                return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
            }
            else if constexpr(WarpGemmM == 16)
            {
                return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
            }
            else
            {
                // Remaining case allowed by the static_assert above: WarpGemmM == 4
                return WarpGemmMfmaBf16Bf16F32M4N64K16{};
            }
        }
        else if constexpr(kIsFp8)
        {
            // fp8 is only wired up for the 32-wide warp tile.
            static_assert(WarpGemmM == 32);

            // TODO: the swizzle factor is hard-coded; other values may give incorrect results.
            constexpr index_t kSwizzleFactor = 4;
            return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<kSwizzleFactor>{};
        }
        // TODO: bf8_t is not handled. For any other type combination the lambda
        // returns nothing, so initialising qk_warp_gemm from it does not compile.
    }();

    using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<QDataT,
                                                                 KDataT,
                                                                 SaccDataT,
                                                                 typename FmhaShape::Gemm0BlockWarps,
                                                                 decltype(qk_warp_gemm)>;

    // A dedicated implementation is returned when no more than one warp runs gemm0.
    if constexpr(Problem::kNumGemm0Warps <= 1)
    {
        return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
    }
    else
    {
        return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
    }
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
a94ac4bb
...
...
@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kQLoadOnce
=
false
;
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
false
;
static_assert
(
kKLoadOnce
==
Policy
::
KLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment