Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6ad5d053
Commit
6ad5d053
authored
Oct 14, 2024
by
Po Yen, Chen
Browse files
Duplicate policy code instead of inherit from old one
parent
021c7e84
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
732 additions
and
89 deletions
+732
-89
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
+732
-89
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
View file @
6ad5d053
...
@@ -10,34 +10,186 @@ namespace ck_tile {
...
@@ -10,34 +10,186 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
struct
BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
struct
BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
{
{
// start copy from BlockFmhaPipelineQXCustomPolicy<true>
static
constexpr
bool
QLoadOnce
=
true
;
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeQ
()
{
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
0
;
}
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// TODO: GetAlignment*() currently didn't consider if need padding or not
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// so in pipeline still need check padding requirement
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentQ
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
return
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
K1
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
K2
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K1
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
K0
=
kKPerBlock
/
(
K1
*
K2
);
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
swizzle_factor
>
{};
}
// TODO - bf8_t
}();
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// end copy from BlockFmhaPipelineQXCustomPolicy<true>
// start copy from BlockFmhaPipelineQXKSVSCustomPolicy
static
constexpr
bool
AsyncCopyK
=
false
;
static
constexpr
bool
AsyncCopyV
=
false
;
// TODO: this not supported yet
static
constexpr
index_t
NumPrefetchK
=
1
;
static
constexpr
index_t
NumPrefetchV
=
1
;
template
<
index_t
k_prefetches_
,
index_t
v_prefetches_
,
index_t
k_loops_
,
index_t
v_loops_
>
struct
LdsBufferSequence
{
static
constexpr
auto
Make
()
{
return
transform_sequences
(
[
&
](
auto
i
)
{
if
(
i
<
k_loops_
)
return
i
%
k_prefetches_
;
return
(
i
-
k_loops_
)
%
v_prefetches_
;
},
typename
arithmetic_sequence_gen
<
0
,
k_loops_
+
v_loops_
,
1
>::
type
{});
};
using
type
=
remove_cvref_t
<
decltype
(
Make
())
>
;
};
// clang-format off
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
4
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
4
,
2
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
2
,
4
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
3
,
3
>
{
using
type
=
sequence
<
1
,
2
,
0
,
1
,
2
,
0
>
;
};
template
<
>
struct
LdsBufferSequence
<
3
,
3
,
2
,
2
>
{
using
type
=
sequence
<
1
,
2
,
1
,
0
>
;};
// clang-format on
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLdsBufferSequence
()
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
return
typename
LdsBufferSequence
<
NumPrefetchK
,
NumPrefetchV
,
k0_loops
,
k1_loops
>::
type
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackK
()
{
// TODO: this is for 3d layout
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
return
16
/
sizeof
(
KDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentK
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
if
constexpr
(
AsyncCopyK
)
{
return
4
/
sizeof
(
KDataType
);
}
else
{
return
16
/
sizeof
(
KDataType
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackV
()
{
// TODO: this is for 3d layout
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
return
16
/
sizeof
(
VDataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -47,9 +199,12 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -47,9 +199,12 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// [POYENC] updated tile size
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
// TODO: not correct!
// TODO: not correct!
...
@@ -64,6 +219,455 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -64,6 +219,455 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
}
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentBias
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetKVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSingleSmemElementSpaceSize
()
{
// this function assume K/V can share smem
constexpr
index_t
SingleKSize
=
[
&
]()
{
if
constexpr
(
!
AsyncCopyK
)
{
return
MakeKLdsStoreBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
return
NumIssues
*
NumWarps
*
(
warpSize
*
KVector
+
kPad
);
}
}();
constexpr
index_t
SingleVSize
=
[
&
]()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// [POYENC] updated tile size
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
return
(
kKPerBlock
/
kKPack
)
*
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
);
}();
return
max
(
SingleKSize
,
SingleVSize
);
}
template
<
typename
Problem
,
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
auto
q_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
q_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
q_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
q_block_dstr
=
make_static_tile_distribution
(
q_block_dstr_encode
);
return
q_block_dstr
;
}
// TODO: this is used for non async copy desc. unify in the future
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
#if 0 // [POYENC] disabled for now
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
warpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<warpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
#endif
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor_with_offset
(
make_tuple
(
number
<
NumIssues
>
{},
// n0
number
<
NumWarps
>
{},
// n2
number
<
LaneGroups
>
{},
// n1
number
<
kKPerBlock
/
KPack
>
{},
// k0
number
<
KPack
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
kKPerBlock
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
IBuf
*
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
2
,
1
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
#else
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
()
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr
index_t
BufferSize
=
GetSingleSmemElementSpaceSize
<
Problem
>
();
// max(SingleKSize, SingleVSize);
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumPrefetchK
>
{},
// num_buffers
number
<
NumIssues
>
{},
// n0
number
<
NumWarps
>
{},
// n2
number
<
LaneGroups
>
{},
// n1
number
<
kKPerBlock
/
KPack
>
{},
// k0
number
<
KPack
>
{}),
// k1
make_tuple
(
number
<
BufferSize
>
{},
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
kKPerBlock
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumPrefetchK
>
{},
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
3
,
2
>
{},
sequence
<
4
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
#endif
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
constexpr
auto
v_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumPrefetchV
>
{},
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
)
>
{},
number
<
PixelsPerRow
+
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
kKPack
>
{},
number
<
1
>
{});
constexpr
auto
v_lds_block_desc
=
transform_tensor_descriptor
(
v_lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumPrefetchV
>
{},
number
<
kNPerBlock
/
NPerRow
>
{},
number
<
NPerRow
>
{})),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
0
,
2
,
3
>
{},
sequence
<
1
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
v_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeKV
()
{
// TODO: assume Q is in register
// TODO: assume K/V has same data type
constexpr
index_t
single_smem_size
=
GetSingleSmemElementSpaceSize
<
Problem
>
()
*
sizeof
(
typename
Problem
::
KDataType
);
return
GetSmemSizeQ
<
Problem
>
()
+
single_smem_size
*
max
(
NumPrefetchK
,
NumPrefetchV
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
if
constexpr
(
AsyncCopyK
)
{
return
GetSmemSizeKV
<
Problem
>
()
+
GetSmemSizeDropout
<
Problem
>
(
0
);
}
else
{
return
ck_tile
::
max
(
GetSmemSizeKV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
(
0
));
}
}
// this method is only available when Problem::kHasDropout is present
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
std
::
enable_if_t
<
std
::
is_convertible_v
<
decltype
(
Problem
::
kHasDropout
),
bool
>
,
ck_tile
::
index_t
>
GetSmemSizeDropout
(
int
)
{
if
constexpr
(
Problem
::
kHasDropout
)
{
constexpr
auto
gemm_0
=
GetQKBlockGemm
<
Problem
>
();
constexpr
auto
config
=
decltype
(
gemm_0
)
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
return
(
kMPerStep
+
1
)
*
kNPerStep
*
sizeof
(
uint8_t
);
}
else
{
return
0
;
}
}
// fallback version if Problem::kHasDropout is not exist
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
(...)
{
return
0
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
if
constexpr
(
!
AsyncCopyK
)
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; // [POYENC] old tile
// size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
K1
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
static_assert
(
warpSize
*
KVector
>=
kKPerBlock
&&
warpSize
*
KVector
%
kKPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
kNPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
kNPerBlock
*
kKPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
index_t
N0
=
NumIssues
;
constexpr
index_t
N1
=
LaneGroups
;
constexpr
index_t
N2
=
NumWarps
;
constexpr
index_t
K0
=
LanesPerK
;
constexpr
index_t
K1
=
KVector
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeVDramTileDistribution
()
{
{
...
@@ -71,7 +675,8 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -71,7 +675,8 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// [POYENC] updated tile size
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
@@ -131,40 +736,48 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -131,40 +736,48 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
}
}
}
}
// TODO: this is used for non async copy desc. unify in the future
template
<
typename
Problem
,
typename
BlockGemm
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsStoreBlockDescriptor
()
{
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
>
{},
number
<
kKPack
>
{}),
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
k_lds_block_desc_0
,
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// Construct C-Block-HostTensor
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledVRegBlockDescriptor
2
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledVRegBlockDescriptor
()
{
{
// This descriptor only used when V layout is seqlen * hdim
// This descriptor only used when V layout is seqlen * hdim
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
::
VLayout
>
;
static_assert
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
static_assert
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// [POYENC] updated tile size
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
...
@@ -174,7 +787,6 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -174,7 +787,6 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
constexpr
index_t
kKPack
=
GetSmemKPackV
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
static_assert
((
get_warp_size
()
%
(
K2
*
N0
)
==
0
));
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
...
@@ -204,6 +816,81 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -204,6 +816,81 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
}
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
VDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
OaccDataType
,
float
>
)
{
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<>
{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
}
else
{
return
WarpGemmMfmaDispatcher
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
>
{};
}
}();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm
)
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
// end copy from BlockFmhaPipelineQXKSVSCustomPolicy
// TODO: this is used for non async copy desc. unify in the future
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsStoreBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kNPerBlock
>
{},
number
<
kKPack
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
kKPack
>
{},
number
<
kKPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
kKPack
>
{},
number
<
kKPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
// 3d + padding
// 3d + padding
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsStoreBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsStoreBlockDescriptor
()
...
@@ -215,7 +902,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -215,7 +902,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
// [POYENC] updated tile size
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
...
@@ -244,50 +931,6 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
...
@@ -244,50 +931,6 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return
v_lds_block_desc
;
return
v_lds_block_desc
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSingleSmemElementSpaceSize
()
{
// this function assume K/V can share smem
constexpr
index_t
SingleKSize
=
[
&
]()
{
return
MakeKLdsStoreBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}();
constexpr
index_t
SingleVSize
=
[
&
]()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
Banks
=
32
;
// TODO: need change based on arch
constexpr
index_t
PixelsPerRow
=
Banks
*
4
/
sizeof
(
VDataType
);
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
static_assert
(
PixelsPerRow
%
kKPack
==
0
);
constexpr
index_t
NPerRow
=
PixelsPerRow
/
kKPack
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
static_assert
(
kNPerBlock
%
NPerRow
==
0
);
static_assert
(
kKPerBlock
%
kKPack
==
0
);
return
(
kKPerBlock
/
kKPack
)
*
(
kNPerBlock
/
NPerRow
)
*
(
PixelsPerRow
+
kKPack
);
}();
return
max
(
SingleKSize
,
SingleVSize
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeKV
()
{
// TODO: assume Q is in register
// TODO: assume K/V has same data type
constexpr
index_t
single_smem_size
=
GetSingleSmemElementSpaceSize
<
Problem
>
()
*
sizeof
(
typename
Problem
::
KDataType
);
return
QXPolicy
::
template
GetSmemSizeQ
<
Problem
>()
+
single_smem_size
*
max
(
NumPrefetchK
,
NumPrefetchV
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
GetSmemSizeKV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
(
0
));
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment