Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
199f7f71
Commit
199f7f71
authored
Sep 01, 2024
by
carlushuang
Browse files
modify moe
parent
33ceea62
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2452 additions
and
362 deletions
+2452
-362
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
...tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
+15
-0
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
.../ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
+292
-78
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
...e/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
+130
-250
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp
...include/ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp
+410
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp
..._tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp
+27
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
...ile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
+464
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
.../fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
+668
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
...ile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
+1
-11
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp
+116
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
...clude/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
+23
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
.../ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
+15
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+115
-12
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+14
-0
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+25
-0
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+20
-5
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+26
-0
include/ck_tile/core/tensor/tensor_view.hpp
include/ck_tile/core/tensor/tensor_view.hpp
+17
-0
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+68
-0
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+5
-5
No files found.
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_
traits
.hpp
→
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_
permute_enum
.hpp
View file @
199f7f71
...
...
@@ -3,16 +3,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
enum
class
FusedMoePermuteStyle
enum
class
FusedMoeWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
...
...
@@ -20,14 +12,4 @@ enum class FusedMoePermuteStyle
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
no_permute
=
999
,
};
template
<
bool
DownPreShuffled_
=
false
,
FusedMoePermuteStyle
PermuteStyle_
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
FusedMoeTraits
{
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
FusedMoePermuteStyle
PermuteStyle
=
PermuteStyle_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
}
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
View file @
199f7f71
...
...
@@ -14,7 +14,7 @@ namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaPipelineQRKSVSAsync
struct
FusedMoePipeline
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
...
@@ -27,43 +27,49 @@ struct BlockFmhaPipelineQRKSVSAsync
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
ScaleDataType
>
;
using
FusedMoeTileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
FusedMoeTileShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
using
FusedMoeTileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static
constexpr
auto
R_LOG2E
=
1.0
/
log2e_v
<
SaccDataType
>
;
#endif
static
constexpr
index_t
kBlockM_0
=
FusedMoeTileShape
::
kBlockM_0
;
static
constexpr
index_t
kBlockN_0
=
FusedMoeTileShape
::
kBlockN_0
;
static
constexpr
index_t
kBlockK_0
=
FusedMoeTileShape
::
kBlockK_0
;
static
constexpr
index_t
kWarpM_0
=
FusedMoeTileShape
::
kWarpM_0
;
static
constexpr
index_t
kWarpN_0
=
FusedMoeTileShape
::
kWarpN_0
;
static
constexpr
index_t
kWarpK_0
=
FusedMoeTileShape
::
kWarpK_0
;
static
constexpr
index_t
kBlockWarpsM_0
=
FusedMoeTileShape
::
kBlockWarpsM_0
;
static
constexpr
index_t
kBlockWarpsN_0
=
FusedMoeTileShape
::
kBlockWarpsN_0
;
static
constexpr
index_t
kBlockWarpsK_0
=
FusedMoeTileShape
::
kBlockWarpsK_0
;
static
constexpr
index_t
kSubBlockM_0
=
FusedMoeTileShape
::
kSubBlockM_0
;
static
constexpr
index_t
kSubBlockN_0
=
FusedMoeTileShape
::
kSubBlockN_0
;
static
constexpr
index_t
kSubBlockK_0
=
FusedMoeTileShape
::
kSubBlockK_0
;
static
constexpr
index_t
kWarpRepeatM_0
=
FusedMoeTileShape
::
kWarpRepeatM_0
;
static
constexpr
index_t
kWarpRepeatN_0
=
FusedMoeTileShape
::
kWarpRepeatN_0
;
static
constexpr
index_t
kWarpRepeatK_0
=
FusedMoeTileShape
::
kWarpRepeatK_0
;
static
constexpr
index_t
kBlockM_1
=
FusedMoeTileShape
::
kBlockM_1
;
static
constexpr
index_t
kBlockN_1
=
FusedMoeTileShape
::
kBlockN_1
;
static
constexpr
index_t
kBlockK_1
=
FusedMoeTileShape
::
kBlockK_1
;
static
constexpr
index_t
kWarpM_1
=
FusedMoeTileShape
::
kWarpM_1
;
static
constexpr
index_t
kWarpN_1
=
FusedMoeTileShape
::
kWarpN_1
;
static
constexpr
index_t
kWarpK_1
=
FusedMoeTileShape
::
kWarpK_1
;
static
constexpr
index_t
kBlockWarpsM_1
=
FusedMoeTileShape
::
kBlockWarpsM_1
;
static
constexpr
index_t
kBlockWarpsN_1
=
FusedMoeTileShape
::
kBlockWarpsN_1
;
static
constexpr
index_t
kBlockWarpsK_1
=
FusedMoeTileShape
::
kBlockWarpsK_1
;
static
constexpr
index_t
kSubBlockM_1
=
FusedMoeTileShape
::
kSubBlockM_1
;
static
constexpr
index_t
kSubBlockN_1
=
FusedMoeTileShape
::
kSubBlockN_1
;
static
constexpr
index_t
kSubBlockK_1
=
FusedMoeTileShape
::
kSubBlockK_1
;
static
constexpr
index_t
kWarpRepeatM_1
=
FusedMoeTileShape
::
kWarpRepeatM_1
;
static
constexpr
index_t
kWarpRepeatN_1
=
FusedMoeTileShape
::
kWarpRepeatN_1
;
static
constexpr
index_t
kWarpRepeatK_1
=
FusedMoeTileShape
::
kWarpRepeatK_1
;
using
MBlockType
=
decltype
(
GetMatrixCoreSwizzledBlockTIle_0
<
Problem
>
());
static
constexpr
index_t
kBlockNr_0
=
MBlockType
{}
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kBlockKr_0
=
MBlockType
{}
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kBlockWaveFlatten
=
MBlockType
{}
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
...
...
@@ -71,37 +77,7 @@ struct BlockFmhaPipelineQRKSVSAsync
else
{
// minimize occupancy
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)
{
return
1
;
}
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
FmhaMask
::
IsMasking
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
return
2
;
}
}();
...
...
@@ -179,23 +155,261 @@ struct BlockFmhaPipelineQRKSVSAsync
o_gtile_window_tmp
.
get_window_lengths
(),
o_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>());
using
g_thread_type
=
decltype
(
load_tile
(
g_gtile_window
));
using
u_thread_type
=
decltype
(
load_tile
(
u_gtile_window
));
using
d_thread_type
=
decltype
(
load_tile
(
d_gtile_window
));
const
index_t
loops_0
=
(
dim_size
+
kBlockK_0
-
1
)
/
kBlockK_0
;
const
index_t
loops_1
=
(
dim_size
+
kBlockN_1
-
1
)
/
kBlockN_1
;
// auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
// issues_warps_lanes
auto
a_sst_0
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
// issues_warps_lanes
auto
a_sst_1
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
// m*k
auto
a_sld_0
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// m*k
auto
a_sld_1
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
g_thread_type
g_tile
[
2
];
using
WarpGemm0
=
Policy
::
GetWarpGemm0
<
Problem
>
();
using
WarpGemm1
=
Policy
::
GetWarpGemm1
<
Problem
>
();
auto
warp_gemm_0
=
WarpGemm0
{};
auto
warp_gemm_1
=
WarpGemm1
{};
// TODO: N fist, M next
const
index_t
i_mwarp_0
=
get_warp_id
()
/
kBlockWarpsN_0
;
// create and pre-cache a warp-window
auto
make_a_warp_windows
=
[
&
](
auto
a_sld_
)
{
// construct A-warp-window
auto
warp_window
=
make_tile_window
(
a_sld_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm0
::
kM
>
{},
number
<
WarpGemm0
::
kK
>
{}),
a_sld_
.
get_window_origin
()
+
multi_index
<
2
>
{
i_mwarp_0
*
WarpGemm0
::
kM
,
0
},
make_static_tile_distribution
(
typename
WarpGemm0
::
AWarpDstrEncoding
{}));
statically_indexed_array
<
statically_indexed_array
<
decltype
(
warp_window
),
kWarpRepeatK_0
>
,
kWarpRepeatM_0
>
ws
;
// pre-cache the warp windows
static_for
<
0
,
kWarpRepeatM_0
,
1
>
{}([
&
](
auto
i_m_iter
)
{
static_for
<
0
,
kWarpRepeatK_0
,
1
>
{}([
&
](
auto
i_k_iter
)
{
ws
(
i_m_iter
)(
i_k_iter
)
=
warp_window
;
move_tile_window
(
ws
(
i_m_iter
)(
i_k_iter
),
{
i_m_iter
*
NPerBlockPerIter
,
i_k_iter
*
KPerBlockPerIter
});
});
});
return
ws
;
};
auto
a_warp_windows_0
=
make_a_warp_windows
(
a_sld_0
);
auto
a_warp_windows_1
=
make_a_warp_windows
(
a_sld_1
);
constexpr
auto
true_v
=
bool_constant
<
true
>
{};
constexpr
auto
false_v
=
bool_constant
<
false
>
{};
auto
do_load_a0
=
[
&
](
auto
&
a_store_
,
auto
move_
)
{
async_load_tile
(
a_store_
,
a_gtile_window
);
if
constexpr
(
move_
)
move_tile_window
(
a_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockK_0
>
{}});
};
auto
do_load_b0
=
[
&
](
auto
&
g_tile_
,
auto
&
u_tile_
,
auto
move_
)
{
g_tile_
=
load_tile
(
g_gtile_window
);
u_tile_
=
load_tile
(
u_gtile_window
);
if
constexpr
(
move_
)
{
move_tile_window
(
g_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
move_tile_window
(
u_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
}
};
auto
do_load_b1
=
[
&
](
auto
&
d_tile_
,
auto
move_
)
{
d_tile_
=
load_tile
(
d_gtile_window
);
if
constexpr
(
move_
)
{
move_tile_window
(
d_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
}
};
// using AWarpTensor = typename decltype(warp_gemm_0)::AWarpTensor{};
// using CWarpTensor =
auto
acc_g
=
MakeCBlockTile_Gemm0
<
Problem
>
();
auto
acc_u
=
MakeCBlockTile_Gemm0
<
Problem
>
();
// async_load_tile(a_sst_0, a_gtile_window); move_tile_window(a_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); g_tile[0] = load_tile(g_gtile_window);
// move_tile_window(g_gtile_window, {number<0>{}, number<kBlockK_0>{}}); u_tile[0] =
// load_tile(u_gtile_window); move_tile_window(u_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); async_load_tile(a_sst_1, a_gtile_window);
// move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}}); g_tile[1] =
// load_tile(g_gtile_window); move_tile_window(g_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); u_tile[1] = load_tile(u_gtile_window);
// move_tile_window(u_gtile_window, {number<0>{}, number<kBlockK_0>{}});
auto
do_gemm_0
=
[
&
](
auto
&
acc_g_
,
auto
&
acc_u_
,
auto
&
a_windows_
,
auto
&
g_tile_
,
auto
&
u_tile_
)
{
// as_br (asmem, breg)
static_for
<
0
,
kWarpRepeatK_0
,
1
>
{}([
&
](
auto
i_k
)
{
static_for
<
0
,
kWarpRepeatM_0
,
1
>
{}([
&
](
auto
i_m
)
{
const
auto
w_a
=
load_tile
(
a_windows_
(
i_m
)(
i_k
));
static_for
<
0
,
kWarpRepeatN_0
,
1
>
{}([
&
](
auto
i_n
)
{
constexpr
auto
beg_acc
=
sequence
<
i_m
*
kSubBlockM_0
,
i_n
*
kSubBlockN_0
>
{};
constexpr
auto
end_acc
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
>
{};
// 3d indexing for permuted g/u/d
constexpr
auto
beg_b
=
sequence
<
i_m
*
kBlockWarpsM_0
,
i_n
*
kSubBlockN_0
,
0
>
{};
constexpr
auto
end_b
=
sequence
<
(
i_m
+
1
)
*
kBlockWarpsM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
,
0
>
{};
auto
w_acc_g
=
get_slice_tile
(
acc_g_
,
beg_acc
,
end_acc
);
auto
w_acc_u
=
get_slice_tile
(
acc_u_
,
beg_acc
,
end_acc
);
auto
w_g
=
get_slice_tile
(
g_tile_
,
beg_b
,
end_b
);
auto
w_u
=
get_slice_tile
(
u_tile_
,
beg_b
,
end_b
);
warp_gemm_0
(
w_acc_g
,
w_a
,
w_g
);
warp_gemm_0
(
w_acc_u
,
w_a
,
w_u
);
set_slice_tile
(
acc_g_
,
w_acc_g
,
beg_acc
,
end_acc
);
set_slice_tile
(
acc_u_
,
w_acc_u
,
beg_acc
,
end_acc
);
});
});
});
};
auto
do_gemm_1
=
[
&
](
auto
&
acc_d_
,
auto
&
a_tile_
,
auto
&
d_tile_
)
{
// ar_br (areg, breg)
static_for
<
0
,
kWarpRepeatK_1
,
1
>
{}([
&
](
auto
i_k
)
{
static_for
<
0
,
kWarpRepeatM_1
,
1
>
{}([
&
](
auto
i_m
)
{
constexpr
auto
beg_a
=
sequence
<
i_m
*
kSubBlockM_1
,
i_k
*
kSubBlockK_1
>
{};
constexpr
auto
end_a
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_1
,
(
i_k
+
1
)
*
kSubBlockK_1
>
{};
const
auto
w_a
=
get_slice_tile
(
a_tile_
,
beg_a
,
end_a
);
static_for
<
0
,
kWarpRepeatN_1
,
1
>
{}([
&
](
auto
i_n
)
{
constexpr
auto
beg_acc
=
sequence
<
i_m
*
kSubBlockM_0
,
i_n
*
kSubBlockN_0
>
{};
constexpr
auto
end_acc
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
>
{};
// 3d indexing for permuted g/u/d
constexpr
auto
beg_b
=
sequence
<
i_m
*
kBlockWarpsM_0
,
i_n
*
kSubBlockN_0
,
0
>
{};
constexpr
auto
end_b
=
sequence
<
(
i_m
+
1
)
*
kBlockWarpsM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
,
0
>
{};
auto
w_acc_d
=
get_slice_tile
(
acc_d_
,
beg_acc
,
end_acc
);
auto
w_d
=
get_slice_tile
(
d_tile_
,
beg_b
,
end_b
);
warp_gemm_1
(
w_acc_d
,
w_a
,
w_d
);
set_slice_tile
(
acc_d_
,
w_acc_d
,
beg_acc
,
end_acc
);
});
});
});
};
// start of pipeline
do_load_a0
(
a_sst_0
,
true_v
);
do_load_b0
(
g_tile
[
0
],
u_tile
[
0
],
true_v
);
do_load_a0
(
a_sst_1
,
true_v
);
do_load_b0
(
g_tile
[
1
],
u_tile
[
1
],
true_v
);
clear_tile
(
acc_g
);
clear_tile
(
acc_u
);
constexpr
auto
k_per_block_0
=
Problem
::
FusedMoeTileShape
::
kK_a
;
const
index_t
loops_0
=
(
dim_size
+
k_per_block_0
-
1
)
/
k_per_block_0
;
index_t
i_0
=
0
;
while
(
i_0
<
(
loops_0
-
2
))
{
// first buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_0
,
g_tile
[
0
],
u_tile
[
0
]);
do_load_a0
(
a_sst_0
,
true_v
);
do_load_b0
(
g_tile
[
0
],
u_tile
[
0
],
true_v
);
i_0
++
;
// second buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_1
,
g_tile
[
1
],
u_tile
[
1
]);
do_load_a0
(
a_sst_1
,
true_v
);
do_load_b0
(
g_tile
[
1
],
u_tile
[
1
],
true_v
);
i_0
++
;
}
// first buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_0
,
g_tile
[
0
],
u_tile
[
0
]);
// prefetch
d_thread_type
d_tile
[
2
];
do_load_b1
(
d_tile
[
0
],
true_v
);
do_load_b1
(
d_tile
[
1
],
true_v
);
// second buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_1
,
g_tile
[
1
],
u_tile
[
1
]);
// redice acc_g/u
constexpr
auto
acc_spans_0
=
decltype
(
acc_g
)
::
get_distributed_spans
();
sweep_tile_span
(
acc_spans_0
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
acc_spans_0
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
element_wise
::
Silu
{}(
acc_g
(
i_j_idx
),
acc_g
(
i_j_idx
));
acc_g
(
i_j_idx
)
*=
acc_u
(
i_j_idx
);
});
});
constexpr
auto
n_per_block_1
=
Problem
::
FusedMoeTileShape
::
kN_d
;
const
index_t
loops_1
=
(
dim_size
+
n_per_block_1
-
1
)
/
n_per_block_1
;
const
auto
y
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
YDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
YDataType
>
(
acc_g
);
else
return
cast_tile
<
YDataType
>
(
acc_g
);
}();
auto
a_smem_ptr
=
reinterpret_cast
<
ADataType
*>
(
smem_ptr
)
+
a_smem_offset
;
auto
acc_d
=
MakeCBlockTile_Gemm1
<
Problem
>
();
clear_tile
(
acc_d
);
// TODO: reshuffle? 32x32x8 mfma can avlid LDS reshuffle
index_t
i_1
==
0
;
while
(
i_1
<
(
loops_1
-
2
))
{
// first buffer
do_gemm_1
(
acc_d
,
y
,
d_tile
[
0
]);
do_load_b1
(
d_tile
[
0
],
true_v
);
i_1
++
;
// second buffer
do_gemm_1
(
acc_d
,
y
,
d_tile
[
1
]);
do_load_b1
(
d_tile
[
1
],
true_v
);
i_1
++
;
}
auto
smem_0_window
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// first buffer
do_gemm_0
(
a_warp_windows_0
,
g_tile
[
0
],
g_tile
[
1
]);
i_0
++
;
async_load_tile
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})));
for
(
index_t
i_0
=
0
;
i_0
<
loops_0
;
i_0
++
)
{}
// second buffer
do_gemm_0
(
a_warp_windows_1
,
g_tile
[
1
],
g_tile
[
1
]);
i_0
++
;
}
template
<
typename
QDramBlockWindowTmp
,
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
View file @
199f7f71
...
...
@@ -117,14 +117,15 @@ struct FusedMoePipelinePolicy
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
Alignment
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_SimpleMxK_Async
()
{
constexpr
index_t
K_vec
=
Alignment
constexpr
index_t
K_rem
=
KPerBlock
/
K_vec
;
constexpr
index_t
K_vec
=
Alignment
;
constexpr
index_t
K_rem
=
KPerBlock
/
K_vec
;
if
constexpr
(
get_warp_size
()
<
K_rem
)
if
constexpr
(
get_warp_size
()
<
=
K_rem
)
{
static_assert
(
K_rem
%
get_warp_size
()
==
0
);
constexpr
index_t
K_lan
=
get_warp_size
();
// lane within same wave is along gemm-k
constexpr
index_t
K_wav
=
K_rem
/
get_warp_size
();
static_assert
(
K_wav
<=
NumWarps
,
"
not
not support thread has repeat along K yet"
);
static_assert
(
K_wav
<=
NumWarps
,
"
do
not support thread has repeat along K yet"
);
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
...
...
@@ -150,14 +151,56 @@ struct FusedMoePipelinePolicy
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_
lan
,
M_
wav
>
,
sequence
<
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
M_rep
,
M_
wav
,
M_
lan
>
,
sequence
<
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_0
()
{
if
constexpr
(
Problem
::
Traits
::
PermuteStyle
==
FusedMoeWeightPermuteEnum
::
permute_b_nr_kr_kw_nw_kv
)
{
using
WarpGemm
=
GetWarpGemm0
<
Problem
>
{};
// assume warpgemm0/1 are the same
constexpr
index_t
NPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockN_0
;
constexpr
index_t
KPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockK_0
;
constexpr
index_t
Kv
=
GetAlignment_G
<
{
Problem
}
>
();
constexpr
index_t
Nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
Kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
static_assert
(
KPerBlock
%
(
K1
*
K2
)
==
0
);
constexpr
index_t
Nr
=
NPerBlock
/
Nw
;
constexpr
index_t
Kr
=
KPerBlock
/
(
Kv
*
Kw
);
return
sequence
<
Nr
,
Kr
,
Kw
*
Nw
*
Kv
>
{};
// 3D
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_1
()
{
if
constexpr
(
Problem
::
Traits
::
PermuteStyle
==
FusedMoeWeightPermuteEnum
::
permute_b_nr_kr_kw_nw_kv
)
{
using
WarpGemm
=
GetWarpGemm1
<
Problem
>
{};
// assume warpgemm0/1 are the same
constexpr
index_t
NPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockN_1
;
constexpr
index_t
KPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockK_1
;
constexpr
index_t
Kv
=
GetAlignment_G
<
{
Problem
}
>
();
constexpr
index_t
Nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
Kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
static_assert
(
KPerBlock
%
(
K1
*
K2
)
==
0
);
constexpr
index_t
Nr
=
NPerBlock
/
Nw
;
constexpr
index_t
Kr
=
KPerBlock
/
(
Kv
*
Kw
);
return
sequence
<
Nr
,
Kr
,
Kw
*
Nw
*
Kv
>
{};
// 3D
}
}
// Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channel while skip LDS
template
<
index_t
NPerBlock
,
...
...
@@ -166,12 +209,13 @@ struct FusedMoePipelinePolicy
index_t
WavesPerBlock_K
,
typename
WarpGemm
,
index_t
Alignment
,
FusedMoePermuteStyle
PermuteStyle
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
>
FusedMoeWeightPermuteEnum
PermuteStyle
=
FusedMoeWeightPermuteEnum
::
permute_b_nr_kr_kw_nw_kv
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_MatrixCore_Swizzled
()
{
static_assert
(
Alignment
%
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKPerLane
==
0
);
if
constexpr
(
PermuteStyle
==
FusedMoePermute
Style
::
permute_b_nr_kr_kw_nw_kv
)
if
constexpr
(
PermuteStyle
==
FusedMoe
Weight
Permute
Enum
::
permute_b_nr_kr_kw_nw_kv
)
{
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr
index_t
Kv
=
Alignment
;
...
...
@@ -218,20 +262,18 @@ struct FusedMoePipelinePolicy
Alignment
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
template
<
typename
Problem
,
index_t
NSplits
=
2
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
(
number
<
NSplits
>
=
{}
)
{
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermute
Style
::
permute_b_nr_kr_kw_nw_kv
)
if
constexpr
(
PermuteStype
==
FusedMoe
Weight
Permute
Enum
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_u
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_G
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockN_0
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockK_0
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsN_0
;
constexpr
index_t
WavesPerBlock_K
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsK_0
;
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_G
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
...
...
@@ -242,20 +284,18 @@ struct FusedMoePipelinePolicy
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_U
()
template
<
typename
Problem
,
index_t
NSplits
=
2
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_U
(
number
<
NSplits
>
=
{}
)
{
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermute
Style
::
permute_b_nr_kr_kw_nw_kv
)
if
constexpr
(
PermuteStype
==
FusedMoe
Weight
Permute
Enum
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_u
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_U
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockN_0
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockK_0
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsN_0
;
constexpr
index_t
WavesPerBlock_K
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsK_0
;
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_U
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
...
...
@@ -270,16 +310,14 @@ struct FusedMoePipelinePolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
{
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermute
Style
::
permute_b_nr_kr_kw_nw_kv
)
if
constexpr
(
PermuteStype
==
FusedMoe
Weight
Permute
Enum
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_d
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_y
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm1BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm1BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm1
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_D
<
Problem
>
();
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockN_1
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kBlockK_1
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsN_1
;
constexpr
index_t
WavesPerBlock_K
=
Problem
::
FusedMoeTileShape
::
kBlockWarpsK_1
;
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm1
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_D
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
...
...
@@ -290,65 +328,12 @@ struct FusedMoePipelinePolicy
}
}
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
Alignment
,
index_t
KPack
,
index_t
NumPrefetch
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSmemLoadTileDescriptor_SimpleMxK_Async
()
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kBlockSize
=
ck_tile
::
get_warp_size
()
*
NumWarps
;
// Problem::kBlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KVector
=
Alignment
;
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
KPerBlock
&&
warpSize
*
KVector
%
KPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
KPerBlock
/
KVector
;
// within a wave
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// within a wave
constexpr
index_t
NumIssues
=
MPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
MPerBlock
*
KPerBlock
/
(
kBlockSize
*
KVector
));
constexpr
index_t
BufferSize
=
NumIssues
*
NumWarps
*
(
warpSize
*
KVector
+
kPad
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumPrefetch
>
{},
// num_buffers
number
<
NumIssues
>
{},
// n0
number
<
NumWarps
>
{},
// n2
number
<
LaneGroups
>
{},
// n1
number
<
KPerBlock
/
KPack
>
{},
// k0
number
<
KPack
>
{}),
// k1
make_tuple
(
number
<
BufferSize
>
{},
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
number
<
warpSize
*
KVector
+
kPad
>
{},
number
<
KPerBlock
>
{},
number
<
KPack
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
lds_block_desc
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumPrefetch
>
{},
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
3
,
2
>
{},
sequence
<
4
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStore
BlockDescriptor
_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStore
Desc
_A
()
{
// A async->LDS
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_
a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_
a
;
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kBloc
kM_
0
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kBloc
kK_
0
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
...
...
@@ -359,7 +344,7 @@ struct FusedMoePipelinePolicy
static_assert
(
kKPerBlock
%
kVector
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
kVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>
warpSize
)
if
constexpr
(
LanesPerK
>
=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
...
...
@@ -433,7 +418,7 @@ struct FusedMoePipelinePolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
SmemLoadTileDistribution
_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LdsLoadDesc
_A
()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
...
...
@@ -442,8 +427,8 @@ struct FusedMoePipelinePolicy
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_
a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_
a
;
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kBloc
kM_
0
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kBloc
kK_
0
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
...
...
@@ -454,12 +439,12 @@ struct FusedMoePipelinePolicy
static_assert
(
kKPerBlock
%
kVector
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
kVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>
warpSize
)
if
constexpr
(
LanesPerK
>
=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
if
constexpr
(
wavesPerK
>
=
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
...
...
@@ -526,96 +511,6 @@ struct FusedMoePipelinePolicy
return
lds_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeASmemStoreTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignment_A
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
constexpr
index_t
NumPrefetch
=
Problem
::
Traits
::
NumPrefetchA
;
return
MakeSmemStoreBlockDescriptor_SimpleMxK_Async
<
kMperBlock
,
kKPerBlock
,
kBlockSize
,
NumWarps
,
KPack
,
Alignment
>
();
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemStoreTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
return MakeSmemStoreTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeUSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignment_U<Problem>();
constexpr index_t KPack = GetSmemKPackU<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeDSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignment_D<Problem>();
constexpr index_t KPack = GetSmemKPackD<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
#endif
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm0
()
...
...
@@ -640,72 +535,57 @@ struct FusedMoePipelinePolicy
Problem
::
FusedMoeTileShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
/*TransposeC*/
>
{};
}
#if 0
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr auto
Get
Gemm0()
CK_TILE_HOST_DEVICE
constexpr
auto
MakeCBlockTile_
Gemm0
()
const
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::ADataType,
typename Problem::GDataType, // UDataType is the same
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::FusedMoeTileShape::kM_a,
Problem::FusedMoeTileShape::kN_g * 2,
Problem::FusedMoeTileShape::kK_a>>;
constexpr auto warp_gemm = []() {
return WarpGemmMfmaDispatcher<
typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}();
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename Problem::FusedMoeTileShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
using
TileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
constexpr
index_t
BlockWarpsM
=
TileShape
::
kBlockWarpsM_0
;
constexpr
index_t
BlockWarpsN
=
TileShape
::
kBlockWarpsN_0
;
constexpr
index_t
WarpRepeatM
=
TileShape
::
kWarpRepeatM_0
;
constexpr
index_t
WarpRepeatN
=
TileShape
::
kWarpRepeatN_0
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm0
<
Problem
>
())
>
;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
WarpRepeatM
,
BlockWarpsM
>
,
sequence
<
WarpRepeatN
,
BlockWarpsN
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr auto
Get
Gemm1()
CK_TILE_HOST_DEVICE
constexpr
auto
MakeCBlockTile_
Gemm1
()
const
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::FusedMoeTileShape::kM_a,
Problem::FusedMoeTileShape::kN_d,
Problem::FusedMoeTileShape::kK_y>>;
constexpr auto warp_gemm = []() {
return WarpGemmMfmaDispatcher<
typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}();
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
typename Problem::FusedMoeTileShape::Gemm1BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
using
TileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
constexpr
index_t
BlockWarpsM
=
TileShape
::
kBlockWarpsM_1
;
constexpr
index_t
BlockWarpsN
=
TileShape
::
kBlockWarpsN_1
;
constexpr
index_t
WarpRepeatM
=
TileShape
::
kWarpRepeatM_1
;
constexpr
index_t
WarpRepeatN
=
TileShape
::
kWarpRepeatN_1
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
WarpRepeatM
,
BlockWarpsM
>
,
sequence
<
WarpRepeatN
,
BlockWarpsN
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
#endif
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/kernel/fused_moe_kernel.hpp
View file @
199f7f71
...
...
@@ -5,18 +5,18 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
//
//
clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
...
...
@@ -25,12 +25,9 @@
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|-
// exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
...
...
@@ -55,8 +52,7 @@
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
...
...
@@ -73,7 +69,7 @@
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
//
//
clang-format on
//
namespace
ck_tile
{
...
...
@@ -81,81 +77,45 @@ namespace ck_tile {
template
<
typename
TilePartitioner_
,
typename
FusedMoePipeline_
,
typename
EpiloguePipeline_
>
struct
FusedMoeKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FusedMoePipeline
=
ck_tile
::
remove_cvref_t
<
FusedMoePipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FusedMoePipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FusedMoePipeline
::
kBlockPerCu
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
FusedMoePipeline
=
remove_cvref_t
<
FusedMoePipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
// TODO: not used
static
constexpr
index_t
kBlockSize
=
FusedMoePipeline
::
kBlockSize
;
static
constexpr
index_t
kBlockPerCu
=
FusedMoePipeline
::
kBlockPerCu
;
static_assert
(
kBlockPerCu
>
0
);
static
constexpr
ck_tile
::
index_t
kBlockPerCuInput
=
FusedMoePipeline
::
Problem
::
kBlockPerCu
;
using
ADataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
ADataType
>
;
using
GDataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
GDataType
>
;
using
UDataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
UDataType
>
;
using
DDataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
DDataType
>
;
using
ODataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
ODataType
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
AccDataType
>
;
using
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
ScaleDataType
>
;
using
DLayout
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
DLayout
>
;
using
FusedMoeTileShape
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
FusedMoeTileShape
>
;
static
constexpr
index_t
kBlockPerCuInput
=
FusedMoePipeline
::
Problem
::
kBlockPerCu
;
using
ADataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
ADataType
>
;
using
GDataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
GDataType
>
;
using
UDataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
UDataType
>
;
using
DDataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
DDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
ODataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
AccDataType
>
;
using
ScaleDataType
=
remove_cvref_t
<
typename
FusedMoePipeline
::
ScaleDataType
>
;
using
FusedMoeTileShape
=
remove_cvref_t
<
typename
FusedMoePipeline
::
FusedMoeTileShape
>
;
static
constexpr
bool
kPadDimSize
=
FusedMoePipeline
::
kPadDimSize
;
static
constexpr
bool
kPadHiddenSize
=
FusedMoePipeline
::
kPadHiddenSize
;
static
constexpr
bool
kPadSeqLenQ
=
FusedMoePipeline
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
FusedMoePipeline
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
FusedMoePipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FusedMoePipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FusedMoePipeline
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
FusedMoePipeline
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
FusedMoePipeline
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
FusedMoePipeline
::
Problem
::
kDoFp8StaticQuant
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FusedMoePipeline
::
FmhaMask
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
using
bfs
=
typename
FusedMoePipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadSeqLenK
)
n
+=
"sk"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
if
(
kPadHeadDimV
)
n
+=
"dv"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
ADataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FusedMoePipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
DLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
}
template
<
ck_tile
::
index_t
I
>
// to avoid duplicated base class prblem, introduce an template
// arg
template
<
index_t
I
>
// to avoid duplicated base class prblem, introduce an template
// arg
struct
FusedMoeEmptyKargs
{
};
...
...
@@ -180,25 +140,31 @@ struct FusedMoeKernel
// const void* num_tokens_post_padded_ptr;
const
void
*
num_sorted_tiles_ptr
;
ck_tile
::
index_t
dim_size
;
ck_tile
::
index_t
hidden_size
;
ck_tile
::
index_t
num_tokens
;
// input number of tokens for current iteration
ck_tile
::
index_t
num_experts
;
// number of groups
// ck_tile::index_t top_k; // need this?
ck_tile
::
index_t
stride_a
;
ck_tile
::
index_t
stride_g
;
ck_tile
::
index_t
stride_u
;
ck_tile
::
index_t
stride_d
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_g_expert
;
ck_tile
::
index_t
stride_u_expert
;
ck_tile
::
index_t
stride_d_expert
;
index_t
dim_size
;
index_t
hidden_size
;
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
// index_t top_k; // need this?
index_t
stride_a
;
index_t
stride_gu
;
// assume g/u have same stride
// index_t stride_u;
index_t
stride_d
;
index_t
stride_o
;
index_t
stride_expert_gu
;
// assume g/u have same stride
index_t
stride_expert_d
;
};
using
Kargs
=
FusedMoeCommonKargs
;
// std::conditional_t<kIsGroupMode, FusedMoeGroupModeKargs,
// FusedMoeBatchModeKargs>;
struct
FusedMoeMatrixCoreShuffleKargs
:
public
FusedMoeCommonKargs
{
// batch*nr_0*kr_0*waveflattern, now stride_kr is the stride in above
index_t
stride_gu_nr
;
index_t
stride_d_nr
;
};
// TODO: switch karg based on
using
Kargs
=
FusedMoeMatrixCoreShuffleKargs
;
// host args are used inside host API
// and should be POD data structure
...
...
@@ -217,21 +183,21 @@ struct FusedMoeKernel
// const void* num_tokens_post_padded_ptr;
const
void
*
num_sorted_tiles_ptr
;
ck_tile
::
index_t
dim_size
;
ck_tile
::
index_t
hidden_size
;
ck_tile
::
index_t
num_tokens
;
// input number of tokens for current iteration
ck_tile
::
index_t
num_experts
;
// number of groups
//
ck_tile::
index_t top_k; // need this?
ck_tile
::
index_t
stride_a
;
ck_tile
::
index_t
stride_g
;
ck_tile
::
index_t
stride_u
;
ck_tile
::
index_t
stride_d
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_
g_
expert
;
ck_tile
::
index_t
stride_
u_
expert
;
ck_tile
::
index_t
stride_
d_
expert
;
index_t
dim_size
;
index_t
hidden_size
;
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
// index_t top_k; // need this?
index_t
stride_a
;
index_t
stride_g
;
index_t
stride_u
;
index_t
stride_d
;
index_t
stride_o
;
index_t
stride_expert
_gu
;
index_t
stride_expert
_gu
;
index_t
stride_expert
_d
;
};
using
Hargs
=
FusedMoeCommonHargs
;
...
...
@@ -244,45 +210,53 @@ struct FusedMoeKernel
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
FusedMoePipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
return
max
(
FusedMoePipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
ck_tile
::
index_t
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
ck_tile
::
index_t
*>
(
kargs
.
num_sorted_tiles_ptr
));
ck_tile
::
index_t
tile_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
;);
index_t
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
index_t
*>
(
kargs
.
num_sorted_tiles_ptr
));
index_t
nr_0
=
kargs
.
hidden_size
/
FusedMoePipeline
::
kBlockNr_0
;
index_t
kr_0
=
kargs
.
dim_size
/
FusedMoePipeline
::
kBlockKr_0
;
index_t
nr_1
=
kargs
.
dim_size
/
FusedMoePipeline
::
kBlockNr_1
;
index_t
kr_1
=
kargs
.
hidden_size
/
FusedMoePipeline
::
kBlockKr_1
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem_0
[
FusedMoePipeline
::
GetSmemSizeSingleBuffer
()];
__shared__
CK_TILE_LDS_ADDR
ADataType
smem_1
[
FusedMoePipeline
::
GetSmemSizeSingleBuffer
()];
// persistent loop
while
(
true
)
//
while(true)
{
const
auto
[
sorted_tile_id
,
hidden_tile_id
]
=
TilePartitioner
{}(
tile_id
,
num_sorted_tiles
,
kargs
.
hidden_size
);
TilePartitioner
{}(
num_sorted_tiles
,
kargs
.
hidden_size
);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
return
;
ck_tile
::
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
ck_tile
::
index_t
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
// index along hidden_size
ck_tile
::
index_t
hidden_id
=
__builtin_amdgcn_readfirstlane
(
hidden_tile_id
*
FusedMoeTileShape
::
kN_g
);
index_t
hidden_id
=
__builtin_amdgcn_readfirstlane
(
hidden_tile_id
*
FusedMoeTileShape
::
kBlockN_0
);
index_t
hidden_id_nr
=
__builtin_amdgcn_readfirstlane
(
hidden_tile_id
*
block_nr
);
const
auto
a_coord
=
FusedMoePipeline
::
GetAIndex
();
// 2d thread offset, [i_row, i_col]
const
auto
token_
coor
d
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
FusedMoeTileShape
::
kM_
a
;
const
auto
sorted_
token_
i
d
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
FusedMoeTileShape
::
kBloc
kM_
0
;
index_t
token_id
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
token_
coor
d
];
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_
token_
i
d
];
ScaleDataType
scale
=
reinterpret_cast
<
const
ScaleDataType
*>
(
kargs
.
sorted_weight_ptr
)[
token_
coor
d
];
reinterpret_cast
<
const
ScaleDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_
token_
i
d
];
const
auto
a_gtile_window
=
[
&
]()
{
// A is already pre-padded in previous kernel
const
ADataType
*
a_ptr
=
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
auto
a_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
...
...
@@ -299,116 +273,101 @@ struct FusedMoeKernel
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
a_gtile_window_
=
make_tile_window
(
a_gather_view_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kM_a
>
{},
number
<
FmhaPipeline
::
kK_a
>
{}),
{
0
,
0
});
const
auto
a_gtile_window_
=
make_tile_window
(
a_gather_view_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kBlockM_0
>
{},
number
<
FusedMoePipeline
::
kBlockK_0
>
{}),
{
0
,
0
});
return
a_gtile_window_
;
}();
// TODO: gtile using NSub to have less register pressure
const
auto
g_gtile_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_
g_
expert
+
hidden_id
*
kargs
.
stride_g
;
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_expert
_gu
+
hidden_id
_nr
*
kargs
.
stride_g
u_nr
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
kargs
.
hidden_size
,
kargs
.
dim_size
),
make_tuple
(
kargs
.
stride_g
,
1
),
make_tuple
(
nr_0
,
kr_0
,
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}
),
make_tuple
(
stride_g
u_nr
,
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}
,
1
),
number
<
FusedMoePipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
FusedMoeShape
::
kN_g
>
{},
number
<
FusedMoeShape
::
kK_a
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
>
{});
const
auto
g_gtile_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kN_g
>
{},
number
<
FmhaPipeline
::
kK_a
>
{}),
{
0
,
0
});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
FusedMoePipeline
::
kBlockNr_0
>
{},
number
<
FusedMoePipeline
::
kBlockKr_0
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
,
0
>
{});
const
auto
g_gtile_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kBlockNr_0
>
{},
number
<
FusedMoePipeline
::
kBlockKr_0
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
{
0
,
0
,
0
});
return
g_gtile_window_
;
}();
const
auto
u_gtile_window
=
[
&
]()
{
const
UDataType
*
u_ptr
=
reinterpret_cast
<
const
UDataType
*>
(
kargs
.
u_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_
u_
expert
+
hidden_id
*
kargs
.
stride_
u
;
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_expert
_gu
+
hidden_id
_nr
*
kargs
.
stride_
gu_nr
;
const
auto
u_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
u_ptr
,
make_tuple
(
kargs
.
hidden_size
,
kargs
.
dim_size
),
make_tuple
(
kargs
.
stride_
u
,
1
),
make_tuple
(
nr_0
,
kr_0
,
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}
),
make_tuple
(
stride_
gu_nr
,
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}
,
1
),
number
<
FusedMoePipeline
::
kAlignmentU
>
{},
number
<
1
>
{});
const
auto
u_view_1_
=
pad_tensor_view
(
u_view_
,
make_tuple
(
number
<
FusedMoeShape
::
kN_u
>
{},
number
<
FusedMoeShape
::
kK_a
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
>
{});
const
auto
u_gtile_window_
=
make_tile_window
(
u_view_1_
,
make_tuple
(
number
<
FusedMoeShape
::
kN_u
>
{},
number
<
FusedMoeShape
::
kK_a
>
{}),
{
0
,
0
});
const
auto
u_view_1_
=
pad_tensor_view
(
u_view_
,
make_tuple
(
number
<
FusedMoePipeline
::
kBlockNr_0
>
{},
number
<
FusedMoePipeline
::
kBlockKr_0
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
,
0
>
{});
const
auto
u_gtile_window_
=
make_tile_window
(
u_view_1_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kBlockNr_0
>
{},
number
<
FusedMoePipeline
::
kBlockKr_0
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
{
0
,
0
,
0
});
return
u_gtile_window_
;
}();
const
auto
d_gtile_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
DLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_d_expert
+
hidden_id
*
kargs
.
stride_d
;
}
else
{
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_d_expert
+
hidden_id
;
}
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
kargs
.
stride_expert_d
+
hidden_id_nr
*
kargs
.
stride_d_nr
;
}();
if
constexpr
(
std
::
is_same_v
<
DLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
kargs
.
hidden_size
,
kargs
.
dim_size
),
make_tuple
(
kargs
.
stride_d
,
1
),
number
<
FusedMoePipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
FusedMoeShape
::
kK_y
>
{},
number
<
FusedMoeShape
::
kN_d
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
>
{});
const
auto
d_gtile_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
FusedMoeShape
::
kK_y
>
{},
number
<
FusedMoeShape
::
kN_d
>
{}),
{
0
,
0
});
return
d_gtile_window_
;
}
else
{
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
kargs
.
dim_size
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_d
,
1
),
number
<
FusedMoePipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
FusedMoeShape
::
kN_d
>
{},
number
<
FusedMoeShape
::
kK_y
>
{}),
sequence
<
kPadHiddenSize
,
kPadDimSize
>
{});
const
auto
d_gtile_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
FusedMoeShape
::
kN_d
>
{},
number
<
FusedMoeShape
::
kK_y
>
{}),
{
0
,
0
});
return
d_gtile_window_
;
}
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
FusedMoePipeline
::
kBlockWaveFlatten
),
make_tuple
(
kargs
.
stride_d_nr
,
FusedMoePipeline
::
kBlockWaveFlatten
,
1
),
number
<
FusedMoePipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
FusedMoePipeline
::
kBlockNr_1
>
{},
number
<
FusedMoePipeline
::
kBlockKr_1
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
sequence
<
kPadDimSize
,
kPadHiddenSize
,
0
>
{});
const
auto
d_gtile_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
FusedMoePipeline
::
kBlockNr_1
>
{},
number
<
FusedMoePipeline
::
kBlockKr_1
>
{},
number
<
FusedMoePipeline
::
kBlockWaveFlatten
>
{}),
{
0
,
0
,
0
});
return
d_gtile_window_
;
}();
auto
o_gtile_window
=
[
&
]()
{
const
ODataType
*
o_ptr
=
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
);
const
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
o_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
dim_size
),
make_tuple
(
kargs
.
stride_o
,
1
),
...
...
@@ -423,10 +382,11 @@ struct FusedMoeKernel
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
o_gtile_window_
=
make_tile_window
(
o_scatter_view_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kM_a
>
{},
number
<
FmhaPipeline
::
kK_a
>
{}),
{
0
,
0
});
const
auto
o_gtile_window_
=
make_tile_window
(
o_scatter_view_
,
make_tuple
(
number
<
FusedMoeTileShape
::
kBlockM_0
>
{},
number
<
FusedMoePipeline
::
kBlockN_1
>
{}),
{
0
,
0
});
return
o_gtile_window_
;
}();
...
...
@@ -436,9 +396,13 @@ struct FusedMoeKernel
u_gtile_window
,
d_gtile_window
,
o_gtile_window
,
scale
);
tile_id
+=
gridDim
.
x
;
scale
,
smem_0
,
smem_1
,
kargs
.
dim_size
,
kargs
.
hidden_size
);
// tile_id += gridDim.x;
// epilogue not used
}
}
};
...
...
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_tile_partitioner.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/kernel/fused_moe_tile_partitioner.hpp
View file @
199f7f71
...
...
@@ -45,4 +45,31 @@ struct FusedMoeTilePartitioner_PersistentSplitD
}
};
// Simple 2D (non-persistent) tile partitioner for the fused-MoE kernel:
// grid dim.x walks the hidden dimension, grid dim.y walks the (sorted)
// token / M dimension. Each workgroup owns exactly one (m, n) tile.
template <typename FusedMoeTileShape_>
struct FusedMoeTilePartitioner_Linear
{
    using Shape = ck_tile::remove_cvref_t<FusedMoeTileShape_>;

    static constexpr const char* name = "2d";

    // expert x hidden
    // Returns the (i_m, i_n) block coordinate of the current workgroup.
    // fix: the original declaration carried a stray extra ')' after the
    // parameter list, which does not compile.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*hidden_size*/)
    {
        index_t i_n = blockIdx.x; // hidden-dim block index
        index_t i_m = blockIdx.y; // token-dim block index
        return ck_tile::make_tuple(i_m, i_n);
    }

    // Grid covering max_tokens x hidden_size with one workgroup per tile.
    // fix: removed the dead "index_t grids = num_cu * blocks_per_cu;" line —
    // it referenced undeclared identifiers and its result was never used.
    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
    {
        // TODO: this may need tuning
        index_t ms = ck_tile::integer_divide_ceil(max_tokens, Shape::kBlockM_0);
        index_t ns = ck_tile::integer_divide_ceil(hidden_size, Shape::kBlockN_0);
        return dim3(ns, ms, 1); // x: hidden blocks, y: token blocks
    }
};
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
/*
    Fused-MoE pipeline that splits the gemm-n of the B matrix in two
    ("NSplit2") to lower register pressure (assumes B is much larger than A).
    Stage 0: Y = activation(A @ G) * (A @ U)  (gate/up projections + Silu)
    Stage 1: O += Y @ D                        (down projection, accumulated
             into the output window via update_tile)
    A is staged through two LDS buffers (smem_0 / smem_1) with async copies;
    G/U/D are pre-shuffled ("matrix-core swizzled") and loaded straight to
    registers, double-buffered by splitting gemm-n (stage 0) / gemm-k
    (stage 1) in half.
*/
template <typename Problem_, typename Policy_ = FusedMoePipelineNSplit2Policy>
struct FusedMoePipelineNSplit2
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    // element types, all taken from the Problem descriptor
    using ADataType     = remove_cvref_t<typename Problem::ADataType>;     // activation
    using GDataType     = remove_cvref_t<typename Problem::GDataType>;     // gate weight
    using UDataType     = remove_cvref_t<typename Problem::UDataType>;     // up weight
    using DDataType     = remove_cvref_t<typename Problem::DDataType>;     // down weight
    using ODataType     = remove_cvref_t<typename Problem::ODataType>;     // output
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;   // gemm accumulator
    using ScaleDataType = remove_cvref_t<typename Problem::ScaleDataType>; // per-token scale
    using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // ---- gemm-0 (A @ G / A @ U) tile geometry, mirrored from the tile shape ----
    static constexpr index_t kBlockNSub_0 = FusedMoeTileShape::kBlockNSub_0;
    static constexpr index_t kBlockM_0 = FusedMoeTileShape::kBlockM_0;
    static constexpr index_t kBlockN_0 = FusedMoeTileShape::kBlockN_0;
    static constexpr index_t kBlockK_0 = FusedMoeTileShape::kBlockK_0;
    static constexpr index_t kWarpM_0 = FusedMoeTileShape::kWarpM_0;
    static constexpr index_t kWarpN_0 = FusedMoeTileShape::kWarpN_0;
    static constexpr index_t kWarpK_0 = FusedMoeTileShape::kWarpK_0;
    static constexpr index_t kBlockWarpsM_0 = FusedMoeTileShape::kBlockWarpsM_0;
    static constexpr index_t kBlockWarpsN_0 = FusedMoeTileShape::kBlockWarpsN_0;
    static constexpr index_t kBlockWarpsK_0 = FusedMoeTileShape::kBlockWarpsK_0;
    static constexpr index_t kSubBlockM_0 = FusedMoeTileShape::kSubBlockM_0;
    static constexpr index_t kSubBlockN_0 = FusedMoeTileShape::kSubBlockN_0;
    static constexpr index_t kSubBlockK_0 = FusedMoeTileShape::kSubBlockK_0;
    static constexpr index_t kWarpRepeatM_0 = FusedMoeTileShape::kWarpRepeatM_0;
    static constexpr index_t kWarpRepeatN_0 = FusedMoeTileShape::kWarpRepeatN_0;
    static constexpr index_t kWarpRepeatK_0 = FusedMoeTileShape::kWarpRepeatK_0;
    static_assert(kBlockN_0 == 2 * kBlockNSub_0); // this pipeline only support split2
    static_assert(kWarpRepeatN_0 % 2 == 0);

    // ---- gemm-1 (Y @ D) tile geometry ----
    static constexpr index_t kBlockM_1 = FusedMoeTileShape::kBlockM_1;
    static constexpr index_t kBlockN_1 = FusedMoeTileShape::kBlockN_1;
    static constexpr index_t kBlockK_1 = FusedMoeTileShape::kBlockK_1;
    static constexpr index_t kWarpM_1 = FusedMoeTileShape::kWarpM_1;
    static constexpr index_t kWarpN_1 = FusedMoeTileShape::kWarpN_1;
    static constexpr index_t kWarpK_1 = FusedMoeTileShape::kWarpK_1;
    static constexpr index_t kBlockWarpsM_1 = FusedMoeTileShape::kBlockWarpsM_1;
    static constexpr index_t kBlockWarpsN_1 = FusedMoeTileShape::kBlockWarpsN_1;
    static constexpr index_t kBlockWarpsK_1 = FusedMoeTileShape::kBlockWarpsK_1;
    static constexpr index_t kSubBlockM_1 = FusedMoeTileShape::kSubBlockM_1;
    static constexpr index_t kSubBlockN_1 = FusedMoeTileShape::kSubBlockN_1;
    static constexpr index_t kSubBlockK_1 = FusedMoeTileShape::kSubBlockK_1;
    static constexpr index_t kWarpRepeatM_1 = FusedMoeTileShape::kWarpRepeatM_1;
    static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
    static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;

    // swizzled (Nr, Kr, wave-flatten) block tile for the pre-shuffled weights
    using MBlockType_0 = decltype(Policy::GetMatrixCoreSwizzledBlockTIle_0<Problem>());
    // NOTE(review): 'MBlockType_0{}::at(...)' looks malformed — a member call on a
    // temporary should be 'MBlockType_0{}.at(...)'; confirm against ck_tile sequence API.
    static constexpr index_t kBlockNr_0 = MBlockType_0{}::at(number<0>{});
    static constexpr index_t kBlockKr_0 = MBlockType_0{}::at(number<1>{});
    static constexpr index_t kBlockWaveFlatten = MBlockType_0{}::at(number<2>{});
    static_assert(kBlockNr_0 % 2 == 0);
    static constexpr index_t kBlockSubNr_0 = kBlockNr_0 / 2; // per-split Nr of stage 0

    using MBlockType_1 = decltype(Policy::GetMatrixCoreSwizzledBlockTIle_1<Problem>());
    static constexpr index_t kBlockNr_1 = MBlockType_1{}::at(number<0>{});
    static constexpr index_t kBlockKr_1 = MBlockType_1{}::at(number<1>{});
    static constexpr index_t kBlockSubKr_1 = kBlockKr_1 / 2; // per-split Kr of stage 1
    // stage-0 N output feeds stage-1 K input, so the split sizes must match
    static_assert(kBlockSubNr_0 == kBlockSubKr_1);

    // global-load vector widths, decided by the policy
    static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentU = Policy::GetAlignment_U<Problem>();
    static constexpr index_t kAlignmentD = Policy::GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::GetAlignment_O<Problem>();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "fused_moe_ns2";

    // NOTE(review): 'kHasDropout' is not declared anywhere in this struct or its
    // visible dependencies (likely a leftover from the FMHA pipeline this file was
    // derived from) — confirm it exists in Problem/Traits or remove this alias.
    using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

    // TODO: there are multiple buffers
    // NOTE(review): 'Policy<Problem>::...' treats Policy (a type) as a template;
    // presumably 'Policy::template GetSmemSizeSingleBuffer<Problem>()' was intended.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSingleBuffer()
    {
        return Policy<Problem>::GetSmemSizeSingleBuffer();
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetAIndex()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOIndex()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord = o_dist.calculate_index();
        return o_coord;
    }

    // Run the whole two-stage pipeline for one block tile.
    //   a/g/u/d/o windows: global-memory tile windows prepared by the kernel
    //   scale:  per-token weight (currently unused in the visible body)
    //   smem_0/smem_1: the two LDS buffers used to double-buffer A
    //   dim_size: gemm-0 K extent / gemm-1 N extent (drives both loop counts)
    template <typename AGlobalTensorView,
              typename GGlobalTileWindow,
              typename UGlobalTileWindow,
              typename DGlobalTileWindow,
              typename OGlobalTensorView>
    CK_TILE_DEVICE auto operator()(const AGlobalTensorView& a_gtile_window_tmp,
                                   const GGlobalTileWindow& g_gtile_window_tmp,
                                   const UGlobalTileWindow& u_gtile_window_tmp,
                                   const DGlobalTileWindow& d_gtile_window_tmp,
                                   OGlobalTensorView& o_gtile_window_tmp,
                                   // const void * sorted_weight_ptr,
                                   ScaleDataType scale,
                                   CK_TILE_LDS_ADDR void* smem_0,
                                   CK_TILE_LDS_ADDR void* smem_1,
                                   index_t dim_size,
                                   index_t /*hidden_size*/)
    {
        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        // re-wrap the incoming windows with this pipeline's thread distributions
        auto a_win = make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
                                      a_gtile_window_tmp.get_window_lengths(),
                                      a_gtile_window_tmp.get_window_origin(),
                                      Policy::template MakeGlobalTileDistribution_A<Problem>());

        // two G windows, one per N-split half (offset by kBlockSubNr_0 along Nr)
        auto g_win = generate_tuple(
            [&](auto i) {
                return make_tile_window(
                    g_gtile_window_tmp.get_bottom_tensor_view(),
                    make_tuple(number<kBlockSubNr_0>{}, number<kBlockKr_0>{}, number<kBlockWaveFlatten>{}),
                    {number<kBlockSubNr_0 * i>{}, I0, I0},
                    Policy::template MakeGlobalTileDistribution_G<Problem>());
            },
            number<2>{});

        // two U windows, same split as G
        auto u_win = generate_tuple(
            [&](auto i) {
                return make_tile_window(
                    u_gtile_window_tmp.get_bottom_tensor_view(),
                    make_tuple(number<kBlockSubNr_0>{}, number<kBlockKr_0>{}, number<kBlockWaveFlatten>{}),
                    {number<kBlockSubNr_0 * i>{}, I0, I0},
                    Policy::template MakeGlobalTileDistribution_U<Problem>());
            },
            number<2>{});

        // two D windows, split along Kr (stage-1 K matches stage-0 N split)
        // NOTE(review): this uses MakeGlobalTileDistribution_U — looks like a
        // copy-paste slip; MakeGlobalTileDistribution_D was presumably intended.
        auto d_win = generate_tuple(
            [&](auto i) {
                return make_tile_window(
                    d_gtile_window_tmp.get_bottom_tensor_view(),
                    make_tuple(number<kBlockNr_1>{}, number<kBlockSubKr_1>{}, number<kBlockWaveFlatten>{}),
                    {I0, number<kBlockSubKr_1 * i>{}, I0},
                    Policy::template MakeGlobalTileDistribution_U<Problem>());
            },
            number<2>{});
        // NOTE(review): the following expression is built and immediately discarded
        // (its result is never assigned) — dead leftover? confirm and remove.
        make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
                         d_gtile_window_tmp.get_window_lengths(),
                         d_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeGlobalTileDistribution_D<Problem>());

        auto o_win = make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
                                      o_gtile_window_tmp.get_window_lengths(),
                                      o_gtile_window_tmp.get_window_origin(),
                                      Policy::template MakeOGlobalTileDistribution<Problem>());

        // per-thread register tile types for the weight loads
        using g_thread_type = decltype(load_tile(g_win[I0]));
        using u_thread_type = decltype(load_tile(u_win[I0]));
        using d_thread_type = decltype(load_tile(d_win[I0]));

        // loop trip counts: stage-0 walks K, stage-1 walks N, both over dim_size
        const index_t loops_0 = (dim_size + kBlockK_0 - 1) / kBlockK_0;
        const index_t loops_1 = (dim_size + kBlockN_1 - 1) / kBlockN_1;

        // issues_warps_lanes
        auto a_st0 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
            Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
            {0, 0, 0});
        // issues_warps_lanes
        auto a_st1 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
            Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
            {0, 0, 0});
        // m*k
        auto a_ld0 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
            Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
            {0, 0});
        // m*k
        auto a_ld1 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
            Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
            {0, 0});

        // double-buffered register tiles for G and U
        statically_indexed_array<g_thread_type, 2> g_tls;
        statically_indexed_array<u_thread_type, 2> u_tls;

        // NOTE(review): 'using X = Policy::GetWarpGemm0<Problem>();' aliases a call
        // expression, not a type — presumably 'decltype(Policy::template GetWarpGemm0<Problem>())'.
        using WarpGemm0 = Policy::GetWarpGemm0<Problem>();
        using WarpGemm1 = Policy::GetWarpGemm1<Problem>();
        auto warp_gemm_0 = WarpGemm0{};
        auto warp_gemm_1 = WarpGemm1{};

        // TODO: N fist, M next
        // create and pre-cache a_reg warp-window
        auto make_a_warp_windows = [&](auto a_sld_) {
            const index_t i_mwarp_0 = get_warp_id() / kBlockWarpsN_0;
            // construct A-warp-window
            auto warp_window = make_tile_window(
                a_sld_.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm0::kM>{}, number<WarpGemm0::kK>{}),
                a_sld_.get_window_origin() + multi_index<2>{i_mwarp_0 * WarpGemm0::kM, 0},
                make_static_tile_distribution(typename WarpGemm0::AWarpDstrEncoding{}));
            return warp_window;
        };
        auto a_warp_windows_0 = make_a_warp_windows(a_ld0);
        auto a_warp_windows_1 = make_a_warp_windows(a_ld1);

        // kick off async copy of the next A K-slice into the given LDS buffer
        auto load_a = [&](auto& a_store_) {
            async_load_tile(a_store_, a_win);
            move_tile_window(a_win, {number<0>{}, number<kBlockK_0>{}});
        };

        // load one G/U register tile pair and advance both windows along Kr
        auto load_n = [&](auto& g_tile_, auto& u_tile_, auto& g_window_, auto& u_window_) {
            g_tile_ = load_tile(g_window_);
            u_tile_ = load_tile(u_window_);
            move_tile_window(g_window_, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
            move_tile_window(u_window_, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
        };

        // NOTE(review): d_win is a 2-tuple of windows — 'load_tile(d_win)' /
        // 'move_tile_window(d_win, ...)' on the tuple itself look wrong; expected
        // d_win[i] (and a stage-1 step constant rather than kBlockKr_0). Confirm.
        auto load_d = [&](auto& d_tile_) {
            d_tile_ = load_tile(d_win);
            move_tile_window(d_win, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
        };

        // NOTE(review): these lambdas discard the created tile (missing 'return'),
        // and MakeCBlockTile_Gemm0 is called unqualified — presumably
        // 'return Policy::template MakeCBlockTile_Gemm0<Problem>();'. Confirm.
        auto acc_g = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
        auto acc_u = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});

        // Note this function only do gemm of single Nsplit
        // clang-format off
        auto gemm_0 = [&](auto& acc_g_, auto& acc_u_, auto& a_, auto& g_, auto& u_) {
            static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k) {
                static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m) {
                    constexpr auto beg_a = sequence<i_m * kSubBlockM_0, i_k * kSubBlockK_0>{};
                    constexpr auto end_a = sequence<(i_m + 1) * kSubBlockM_0, (i_k + 1) * kSubBlockK_0>{};
                    auto w_a = get_slice_tile(a_, beg_a, end_a);
                    // only half of the N repeats: one N-split per call
                    static_for<0, kWarpRepeatN_0 / 2, 1>{}([&](auto i_n) {
                        constexpr auto beg_acc = sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
                        constexpr auto end_acc = sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};
                        constexpr auto beg_b = sequence<i_n * kSubBlockN_0, i_k * kSubBlockK_0, 0>{};
                        constexpr auto end_b = sequence<(i_n + 1) * kSubBlockN_0, (i_k + 1) * kSubBlockK_0, 0>{};
                        auto w_acc_g = get_slice_tile(acc_g_, beg_acc, end_acc);
                        auto w_acc_u = get_slice_tile(acc_u_, beg_acc, end_acc);
                        auto w_g = get_slice_tile(g_, beg_b, end_b);
                        auto w_u = get_slice_tile(u_, beg_b, end_b);
                        warp_gemm_0(w_acc_g, w_a, w_g);
                        warp_gemm_0(w_acc_u, w_a, w_u);
                        set_slice_tile(acc_g_, w_acc_g, beg_acc, end_acc);
                        set_slice_tile(acc_u_, w_acc_u, beg_acc, end_acc);
                    });
                });
            });
        };
        // clang-format on

        // clang-format off
        auto gemm_1 = [&](auto& acc_d_, auto& y_, auto& d_) {
            static_for<0, kWarpRepeatK_1, 1>{}([&](auto i_k) {
                static_for<0, kWarpRepeatM_1, 1>{}([&](auto i_m) {
                    constexpr auto beg_a = sequence<i_m * kSubBlockM_1, i_k * kSubBlockK_1>{};
                    constexpr auto end_a = sequence<(i_m + 1) * kSubBlockM_1, (i_k + 1) * kSubBlockK_1>{};
                    const auto w_y = get_slice_tile(y_, beg_a, end_a);
                    static_for<0, kWarpRepeatN_1, 1>{}([&](auto i_n) {
                        constexpr auto beg_acc = sequence<i_m * kSubBlockM_1, i_n * kSubBlockN_1>{};
                        constexpr auto end_acc = sequence<(i_m + 1) * kSubBlockM_1, (i_n + 1) * kSubBlockN_1>{};
                        constexpr auto beg_d = sequence<i_n * kSubBlockN_1, i_k * kSubBlockK_1, 0>{};
                        constexpr auto end_d = sequence<(i_n + 1) * kSubBlockN_1, (i_k + 1) * kSubBlockK_1, 0>{};
                        auto w_acc_d = get_slice_tile(acc_d_, beg_acc, end_acc);
                        auto w_d = get_slice_tile(d_, beg_d, end_d);
                        warp_gemm_1(w_acc_d, w_y, w_d);
                        set_slice_tile(acc_d_, w_acc_d, beg_acc, end_acc);
                    });
                });
            });
        };
        // clang-format on

        // outstanding-issue counts used to size the buffer_load fences below
        constexpr auto issues_a = number<a_win.get_num_of_access()>{};
        constexpr auto issues_g = number<g_win[I0].get_num_of_access()>{};
        constexpr auto issues_u = number<u_win[I0].get_num_of_access()>{};
        constexpr auto issues_b = issues_g + issues_u;
        constexpr auto issues_d = number<d_win[I0].get_num_of_access()>{};
        constexpr auto issues_o = number<o_win.get_num_of_access()>{};

        // start of pipeline
        // clang-format off
        // prologue: prime both LDS buffers for A and both register buffers for G/U
        load_a(a_st0);
        load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
        load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
        load_a(a_st1);
        clear_tile(acc_g[I0]);
        clear_tile(acc_g[I1]);
        clear_tile(acc_u[I0]);
        clear_tile(acc_u[I1]);

        auto a_reg = decltype(load_tile(a_warp_windows_0)){};

        // stage-0 main loop: two K-steps per iteration, ping-ponging the LDS buffers
        index_t i_0 = 0;
        while(i_0 < (loops_0 - 2))
        {
            // first buffer
            buffer_load_fence(issues_b + issues_b + issues_a);
            wave_barrier();
            a_reg = load_tile(a_warp_windows_0);
            buffer_load_fence(issues_b + issues_a);
            gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
            load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
            buffer_load_fence(issues_b + issues_a);
            gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
            load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
            load_a(a_st0);
            i_0++;

            // second buffer
            buffer_load_fence(issues_b + issues_b + issues_a);
            wave_barrier();
            a_reg = load_tile(a_warp_windows_1);
            buffer_load_fence(issues_b + issues_a);
            gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
            load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
            buffer_load_fence(issues_b + issues_a);
            gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
            load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
            load_a(a_st1);
            i_0++;
        }

        // stage-0 epilogue, first buffer
        buffer_load_fence(issues_b + issues_b + issues_a);
        wave_barrier();
        a_reg = load_tile(a_warp_windows_0);
        gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
        load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
        buffer_load_fence(issues_b + issues_a);
        gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
        load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);

        // stage-0 epilogue, second buffer
        buffer_load_fence(issues_b + issues_b);
        wave_barrier();
        a_reg = load_tile(a_warp_windows_1);
        buffer_load_fence(issues_b);
        gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);

        // prefetch
        statically_indexed_array<d_thread_type, 2> d_tls;
        load_d(d_tls[0]);
        load_d(d_tls[1]);
        buffer_load_fence(issues_d + issues_d);
        gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);

        // reduce acc_g/u: y = Silu(gate) * up, elementwise over the block tile
        // NOTE(review): acc_g is a tuple of tiles; 'decltype(acc_g)::get_distributed_spans()'
        // looks wrong — expected decltype(acc_g[I0])::get_distributed_spans(). Confirm.
        constexpr auto acc_spans_0 = decltype(acc_g)::get_distributed_spans();
        sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
            sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                element_wise::Silu{}(acc_g[I0](i_j_idx), acc_g[I0](i_j_idx));
                element_wise::Silu{}(acc_g[I1](i_j_idx), acc_g[I1](i_j_idx));
                acc_g[I0](i_j_idx) *= acc_u[I0](i_j_idx);
                acc_g[I1](i_j_idx) *= acc_u[I1](i_j_idx);
            });
        });

        // cast Y to its storage type (packed fp32->fp16 fast path when available)
        // NOTE(review): 'YDataType' is not declared in this struct — presumably
        // 'typename Problem::YDataType'. Confirm the Problem descriptor provides it.
        const auto y_reg = generate_tuple(
            [&](auto i) {
                if constexpr(std::is_same_v<YDataType, fp16_t>)
                    return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g[i]);
                else
                    return cast_tile<YDataType>(acc_g[i]);
            },
            number<2>{});

        // NOTE(review): unqualified call — presumably Policy::template MakeCBlockTile_Gemm1.
        auto acc_d = MakeCBlockTile_Gemm1<Problem>();
        // TODO: reshuffle? 32x32x8 mfma can avoid LDS reshuffle
        // Second gemm
        clear_tile(acc_d);
        // first buffer
        buffer_load_fence(issues_d);
        gemm_1(acc_d, y_reg[I0], d_tls[I0]);
        load_d(d_tls[I0]);
        // second buffer
        buffer_load_fence(issues_d);
        gemm_1(acc_d, y_reg[I1], d_tls[I1]);
        load_d(d_tls[I1]);
        update_tile(o_win, acc_d); // accumulate this N-slice into global O

        // stage-1 main loop over the remaining N-slices
        index_t i_1 = 0;
        while(i_1 < (loops_1 - 2))
        {
            clear_tile(acc_d);
            // first buffer
            buffer_load_fence(issues_d + issues_o);
            gemm_1(acc_d, y_reg[I0], d_tls[I0]);
            load_d(d_tls[I0]);
            // second buffer
            buffer_load_fence(issues_d + issues_o);
            gemm_1(acc_d, y_reg[I1], d_tls[I1]);
            load_d(d_tls[I1]);
            update_tile(o_win, acc_d);
            i_1++;
        }

        // stage-1 epilogue
        clear_tile(acc_d);
        // first buffer
        buffer_load_fence(issues_d + issues_o);
        gemm_1(acc_d, y_reg[I0], d_tls[I0]);
        // second buffer
        buffer_load_fence(issues_o);
        gemm_1(acc_d, y_reg[I1], d_tls[I1]);
        update_tile(o_win, acc_d);
        // clang-format on
    }
};
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
struct
FusedMoePipelineNSplit2Policy
{
// Number of dwords moved by one async-copy issue (TODO: make configurable).
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
    constexpr index_t dwords_per_issue = 1;
    return dwords_per_issue;
}
// Vector width (in elements of ADataType) for global loads of A.
// A uses async copy, so the width is fixed by the async-copy dword count.
// fix: locals were declared 'static constexpr'; static local variables are
// not permitted in constexpr functions before C++23 — plain constexpr suffices.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
    constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
    constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
// Vector width (in elements of GDataType) for global loads of the gate weight:
// full 16-byte (dwordx4) vectors.
// fix: removed the pointless immediately-invoked lambda and the 'static
// constexpr' locals (static locals are ill-formed in constexpr functions
// before C++23).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
// Vector width (in elements of UDataType) for global loads of the up weight:
// full 16-byte (dwordx4) vectors.
// fix: removed the pointless immediately-invoked lambda and the 'static
// constexpr' locals (static locals are ill-formed in constexpr functions
// before C++23).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_U()
{
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
// Vector width (in elements of DDataType) for global loads of the down weight:
// full 16-byte (dwordx4) vectors.
// fix: removed the pointless immediately-invoked lambda and the 'static
// constexpr' locals (static locals are ill-formed in constexpr functions
// before C++23).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
// Vector width for the output write-back; depends on how O is stored.
//   OAtomic == 0: packed fp16/bf16 atomics — two 16-bit elements per op
//   OAtomic == 1: fp32 atomics — one element per op
//   otherwise:    plain stores — a full 16-byte vector
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
{
    constexpr auto atomic_mode = Problem::Traits::OAtomic;
    if constexpr(atomic_mode == 0)
    {
        // pack fp16/bf16 atomic
        static_assert(sizeof(typename Problem::ODataType) == 2);
        return 2;
    }
    else if constexpr(atomic_mode == 1)
    {
        // fp32 atomic
        return 1;
    }
    else
    {
        return 16 / sizeof(typename Problem::ODataType);
    }
}
// K-pack for the 3D LDS layout: how many elements of DataType_ fit in one
// 16-byte access.
// fix: the body referenced 'typename Problem::DataType_', but 'Problem' is
// not a parameter of this template — the data type is passed in directly.
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
    // TODO: this is for 3d layout
    return 16 / sizeof(remove_cvref_t<DataType_>);
}
// K-pack for the A tile staged in LDS, derived from A's element type.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
    return GetSmemKPack<typename Problem::ADataType>();
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape()
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
return sequence<Kw, Nw, Kv>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockTileNrKr()
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
return sequence<Problem::FusedMoeTileShape::kBlockK_0 / Nw,
Problem::FusedMoeTileShape::kBlockK_0 / (Kw * Kv)>{};
}
#endif
// LDS bytes needed for one A staging buffer; the async-store and the LDS-load
// descriptors view the same memory, so their footprints must agree.
// fix: the locals were declared without a type ('constexpr a_sld_desc = ...'),
// which does not compile — 'constexpr auto' is required.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSingleBuffer()
{
    constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
    constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
    static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
    return a_sld_desc.get_element_space_size();
}
// Plain MxK tile distribution: the contiguous K dimension is vectorized by
// Alignment, the remaining K elements are spread over lanes (and waves when
// K is wide enough), and M takes whatever thread resources are left.
// fix: the K_vec initializer was missing its terminating ';' (did not
// compile); also fixed the "not not support" typo in the diagnostic.
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
    constexpr index_t K_vec = Alignment;
    constexpr index_t K_rem = KPerBlock / K_vec;
    if constexpr(get_warp_size() < K_rem)
    {
        static_assert(K_rem % get_warp_size() == 0);
        constexpr index_t K_lan = get_warp_size();
        // lane within same wave is along gemm-k
        constexpr index_t K_wav = K_rem / get_warp_size();
        static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
        constexpr index_t M_wav = NumWarps / K_wav;
        static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / M_wav;
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
                                       tuple<sequence<1, 2>, sequence<2>>,
                                       tuple<sequence<1, 0>, sequence<1>>,
                                       sequence<1, 2>,
                                       sequence<0, 2>>{});
    }
    else
    {
        constexpr index_t K_lan = K_rem;
        constexpr index_t M_lan = get_warp_size() / K_lan;
        constexpr index_t M_wav = NumWarps;
        static_assert(MPerBlock % (M_lan * M_wav) == 0, "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
}
// optimized version for async
// Same MxK distribution as MakeGlobalTileDistribution_SimpleMxK, but arranged
// for async (direct-to-LDS) copies; the narrow-K branch swaps the M/K thread
// order so subsequent LDS loads are bank-conflict free.
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
    constexpr index_t K_vec = Alignment;          // elements per vectorized access
    constexpr index_t K_rem = KPerBlock / K_vec;  // K accesses left to distribute
    if constexpr(get_warp_size() <= K_rem)
    {
        static_assert(K_rem % get_warp_size() == 0);
        constexpr index_t K_lan = get_warp_size();
        // lane within same wave is along gemm-k
        constexpr index_t K_wav = K_rem / get_warp_size();
        static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
        constexpr index_t M_wav = NumWarps / K_wav;
        static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / M_wav;
        // NOTE: no swap, but hard to avoid LDS bank conflict
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
                                       tuple<sequence<1, 2>, sequence<2>>,
                                       tuple<sequence<1, 0>, sequence<1>>,
                                       sequence<1, 2>,
                                       sequence<0, 2>>{});
    }
    else
    {
        constexpr index_t K_lan = K_rem;                      // all of K fits in part of a wave
        constexpr index_t M_lan = get_warp_size() / K_lan;    // remaining lanes cover M
        constexpr index_t M_wav = NumWarps;
        static_assert(MPerBlock % (M_lan * M_wav) == 0, "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
        // NOTE: swapped for LDS load bank conflict free
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
}
// Shape of the gemm-0 weight block tile in the pre-shuffled ("matrix-core
// swizzled") layout: sequence<Nr, Kr, waveflatten> where the last dim packs
// one wave's (Kw x Nw x Kv) footprint contiguously.
// fixes: 'using WarpGemm = GetWarpGemm0<Problem>{};' aliased a braced call
// (invalid) -> decltype; 'GetAlignment_G<{Problem}>()' had braces around the
// template argument; the static_assert referenced undeclared 'K1'/'K2' —
// the factors here are Kw * Kv.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
    if constexpr(Problem::Traits::PermuteStyle ==
                 FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        using WarpGemm = decltype(GetWarpGemm0<Problem>()); // assume warpgemm0/1 are the same
        constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t Kv = GetAlignment_G<Problem>();
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        static_assert(KPerBlock % (Kw * Kv) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);
        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
// Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channel while skip LDS
// Thread distribution over a (Nr, Kr, waveflatten) pre-shuffled weight tile:
// waves tile the (Nr, Kr) plane, lanes walk (Kw, Nw) inside the flattened
// wave footprint, and Kv elements are vectorized per lane.
// fix: the static_assert referenced undeclared 'K1'/'K2'; the per-wave K
// footprint factors defined here are Kw * Kv.
template <index_t NPerBlock,
          index_t KPerBlock,
          index_t WavesPerBlock_N,
          index_t WavesPerBlock_K,
          typename WarpGemm,
          index_t Alignment,
          FusedMoeWeightPermuteEnum PermuteStyle = FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
    static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
    if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
        constexpr index_t Kv = Alignment;
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        static_assert(KPerBlock % (Kw * Kv) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);
        constexpr index_t Nr_p = WavesPerBlock_N;
        constexpr index_t Kr_p = WavesPerBlock_K;
        constexpr index_t Nr_y = Nr / Nr_p;
        constexpr index_t Kr_y = Kr / Kr_p;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>, // 0
                // major       1                   2                   3
                // minor       0     1             0     1             0   1   2
                tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
                //     waves: Nr_p, Kr_p        lanes: Kw, Nw
                tuple<sequence<1, 2>, sequence<3, 3>>,
                tuple<sequence<1, 1>, sequence<0, 1>>,
                //     per-thread ys: Nr_y, Kr_y, Kv
                sequence<1, 2, 3>,
                sequence<0, 0, 2>>{});
    }
}
// Global-load distribution for the activation (A) tile: a simple MxK layout
// arranged for async (direct-to-LDS) copies.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
    using Shape = typename Problem::FusedMoeTileShape;
    return MakeGlobalTileDistribution_SimpleMxK_Async<Shape::kM_a,
                                                      Shape::kK_a,
                                                      Shape::NumWarps,
                                                      GetAlignment_A<Problem>()>();
}
// Global-load distribution for the gate (G) weight in the pre-shuffled
// matrix-core layout; NSplits is accepted for interface symmetry with the
// N-split pipeline.
// fix: 'remove_cvref_t<GetWarpGemm0<Problem>()>' applied the trait to a call
// expression — 'decltype' is required to obtain the warp-gemm type.
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G(number<NSplits> = {})
{
    constexpr auto PermuteStype = Problem::Traits::PermuteStyle;
    if constexpr(PermuteStype == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
        constexpr index_t Alignment = GetAlignment_G<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              PermuteStype>();
    }
}
/// Build the global-memory tile distribution for the up weight U.
/// Mirrors MakeGlobalTileDistribution_G (same gemm-0 tile shape), only the
/// alignment comes from GetAlignment_U.
/// Fix: wrapped `GetWarpGemm0<Problem>()` in decltype — the original passed an
/// expression where a type is required.
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U(number<NSplits> = {})
{
    constexpr auto PermuteStype = Problem::Traits::PermuteStyle;

    if constexpr(PermuteStype == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
        // was: remove_cvref_t<GetWarpGemm0<Problem>()> (ill-formed, missing decltype)
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
        constexpr index_t Alignment = GetAlignment_U<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              PermuteStype>();
    }
}
/// Build the global-memory tile distribution for the down weight D
/// (gemm-1 tile shape, gemm-1 warp gemm).
/// Fix: wrapped `GetWarpGemm1<Problem>()` in decltype — the original passed an
/// expression where a type is required.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
    constexpr auto PermuteStype = Problem::Traits::PermuteStyle;

    if constexpr(PermuteStype == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_1;
        constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_1;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_1;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_1;
        // was: remove_cvref_t<GetWarpGemm1<Problem>()> (ill-formed, missing decltype)
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        constexpr index_t Alignment = GetAlignment_D<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              PermuteStype>();
    }
}
/// LDS *store* descriptor for the async global->LDS copy of A.
/// Layout is (issues, warps, lanes) so each async issue writes one
/// warp-contiguous row; kPad elements of padding separate warps to reduce
/// bank conflicts.
/// Fixes vs. original: the vector width was declared `kVector` but used as
/// undeclared `KVector` (compile error) — unified on `kVector`; removed the
/// unused `kBlockSize` local.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
    // A async->LDS
    constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
    constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
    constexpr index_t warpSize   = ck_tile::get_warp_size();
    constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(kKPerBlock % kVector == 0);
    constexpr index_t LanesPerK = kKPerBlock / kVector; // how many threads loading K

    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        if constexpr(wavesPerK > NumWarps)
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = kMPerBlock / wavesPerM;
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},  // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<kVector>{}),  // k2
                make_tuple(number<NumWarps * (warpSize * kVector + kPad)>{},  // m0
                           number<wavesPerK * (warpSize * kVector + kPad)>{}, // m1
                           number<warpSize * kVector + kPad>{},               // k0
                           number<kVector>{},                                 // k1
                           number<1>{}),                                      // k2
                number<kVector>{}, // lds store vector(actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_pass_through_transform(number<NumIssues>{}),
                    make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                    make_merge_transform(make_tuple(number<warpSize>{}, number<kVector>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = kMPerBlock / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<kVector>{}),   // k1
            make_tuple(number<NumWarps * (warpSize * kVector + kPad)>{}, // m0
                       number<kKPerBlock>{},                             // m1
                       number<warpSize * kVector + kPad>{},              // m2
                       number<kVector>{},                                // k0
                       number<1>{}),                                     // k1
            number<kVector>{}, // lds store vector(actually no explicit store)
            number<1>{});

        constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<kVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        return lds_block_desc_issues_warps_lanes;
    }
}
/// LDS *load* descriptor for A, returned as a 2-D (M x K) view over the same
/// physical LDS layout that MakeLdsStoreDesc_A writes.
/// Differences vs. the store descriptor (per the original note): the
/// guaranteed last-dim vector length is KPack (ds_read width), and the result
/// is merged down to an MxK layout.
/// Fixes vs. original: undeclared `KVector` unified on the declared `kVector`;
/// the branch guard was `wavesPerK >= NumWarps`, which (unlike the store
/// descriptor's `>`) sent the wavesPerK == NumWarps case into the empty TODO
/// branch — changed to `>` for consistency; removed unused `kBlockSize`.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
    // A async->LDS
    // Note that, this descriptor is only to construct the layout inside LDS
    // in real Gemm pipeline, ds_read may not follow this pattern
    // (may follow that in tile_distribution)
    // below code is almost the same as SmemStore dist, with difference:
    //  1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
    //  2). return descriptor is in NxK 2d layout
    constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
    constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
    constexpr index_t warpSize   = ck_tile::get_warp_size();
    constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(kKPerBlock % kVector == 0);
    constexpr index_t LanesPerK = kKPerBlock / kVector; // how many threads loading K

    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        if constexpr(wavesPerK > NumWarps) // was >=, inconsistent with the store desc
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = kMPerBlock / wavesPerM;
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},  // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<kVector>{}),  // k2
                make_tuple(number<NumWarps * (warpSize * kVector + kPad)>{},  // m0
                           number<wavesPerK * (warpSize * kVector + kPad)>{}, // m1
                           number<warpSize * kVector + kPad>{},               // k0
                           number<kVector>{},                                 // k1
                           number<1>{}),                                      // k2
                number<KPack>{}, // lds load vector
                number<1>{});

            constexpr auto lds_desc_m_k = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
                    make_merge_transform(make_tuple(
                        number<wavesPerK>{}, number<warpSize>{}, number<kVector>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return lds_desc_m_k;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = kMPerBlock / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<kVector>{}),   // k1
            make_tuple(number<NumWarps * (warpSize * kVector + kPad)>{}, // m0
                       number<kKPerBlock>{},                             // m1
                       number<warpSize * kVector + kPad>{},              // m2
                       number<kVector>{},                                // k0
                       number<1>{}),                                     // k1
            number<KPack>{}, // lds load vector
            number<1>{});

        constexpr auto lds_desc_m_k = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_merge_transform(make_tuple(
                           number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                       make_merge_transform(
                           make_tuple(number<LanesPerK>{}, number<kVector>{}))),
            make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return lds_desc_m_k;
    }
}
/// Warp-level MFMA instruction selection for gemm-0 (A x gate/up weights);
/// the C fragment is produced transposed (TransposeC = true).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
    using Tile = typename Problem::FusedMoeTileShape::Gemm0WarpTile;
    return WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                  typename Problem::GDataType,
                                  typename Problem::AccDataType,
                                  Tile::at(number<0>{}),  // warp M
                                  Tile::at(number<1>{}),  // warp N
                                  Tile::at(number<2>{}),  // warp K
                                  true /*TransposeC*/>{};
}
/// Warp-level MFMA instruction selection for gemm-1 (activated Y x down
/// weight D); the C fragment is produced transposed (TransposeC = true).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
    using Tile = typename Problem::FusedMoeTileShape::Gemm1WarpTile;
    return WarpGemmMfmaDispatcher<typename Problem::YDataType,
                                  typename Problem::DDataType,
                                  typename Problem::AccDataType,
                                  Tile::at(number<0>{}),  // warp M
                                  Tile::at(number<1>{}),  // warp N
                                  Tile::at(number<2>{}),  // warp K
                                  true /*TransposeC*/>{};
}
/// Create the per-thread accumulator tensor for gemm-0.
/// The block N extent is divided by NSplits (the pipeline loops over N in
/// NSplits chunks), so WarpRepeatN must be divisible by NSplits.
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0(number<NSplits> = {}) const
{
    using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

    constexpr index_t warps_m  = TileShape::kBlockWarpsM_0;
    constexpr index_t warps_n  = TileShape::kBlockWarpsN_0;
    constexpr index_t repeat_m = TileShape::kWarpRepeatM_0;
    constexpr index_t repeat_n = TileShape::kWarpRepeatN_0;
    static_assert(repeat_n % NSplits == 0);

    // outer (repeat x warp) layout; the warp-gemm's own C layout is embedded below
    constexpr auto outer_encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<repeat_m, warps_m>,
                                         sequence<repeat_n / NSplits, warps_n>>,
                                   tuple<sequence<1, 2>>,
                                   tuple<sequence<1, 1>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    constexpr auto full_encoding = detail::make_embed_tile_distribution_encoding(
        outer_encoding, typename WarpGemm::CWarpDstrEncoding{});

    return make_static_distributed_tensor<CDataType>(
        make_static_tile_distribution(full_encoding));
}
/// Create the per-thread accumulator tensor for gemm-1 (no N split here).
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const
{
    using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

    constexpr index_t warps_m  = TileShape::kBlockWarpsM_1;
    constexpr index_t warps_n  = TileShape::kBlockWarpsN_1;
    constexpr index_t repeat_m = TileShape::kWarpRepeatM_1;
    constexpr index_t repeat_n = TileShape::kWarpRepeatN_1;

    // outer (repeat x warp) layout; the warp-gemm's own C layout is embedded below
    constexpr auto outer_encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<repeat_m, warps_m>,
                                         sequence<repeat_n, warps_n>>,
                                   tuple<sequence<1, 2>>,
                                   tuple<sequence<1, 1>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    constexpr auto full_encoding = detail::make_embed_tile_distribution_encoding(
        outer_encoding, typename WarpGemm::CWarpDstrEncoding{});

    return make_static_distributed_tensor<CDataType>(
        make_static_tile_distribution(full_encoding));
}
/// 3-D block-tile extents <Nr, Kr, Kw*Nw*Kv> of the gemm-0 weight in the
/// matrix-core-swizzled (b_nr_kr_kw_nw_kv) layout.
/// Fixes vs. original: `using WarpGemm = GetWarpGemm0<Problem>{};` and
/// `GetAlignment_G<{Problem}>()` were ill-formed, and the static_assert
/// referenced undefined `K1`/`K2` — the Kr computation shows the intended
/// divisor is Kv * Kw.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
    if constexpr(Problem::Traits::PermuteStyle ==
                 FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
        // assume warpgemm0/1 are the same
        constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t Kv = GetAlignment_G<Problem>();                   // K vector
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;  // N lanes/wave
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; // K lanes/wave
        static_assert(KPerBlock % (Kv * Kw) == 0); // was K1*K2 (undefined names)
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);
        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
/// 3-D block-tile extents <Nr, Kr, Kw*Nw*Kv> of the gemm-1 weight in the
/// matrix-core-swizzled (b_nr_kr_kw_nw_kv) layout.
/// Fixes vs. original: same three ill-formed constructs as the _0 variant
/// (missing decltype, `<{Problem}>`, undefined K1/K2 in the static_assert).
/// NOTE(review): alignment comes from GetAlignment_G even though this is the
/// gemm-1 (down) weight — kept as written, but confirm it should not be
/// GetAlignment_D.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
    if constexpr(Problem::Traits::PermuteStyle ==
                 FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        // assume warpgemm0/1 are the same
        constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_1;
        constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_1;
        constexpr index_t Kv = GetAlignment_G<Problem>();                   // K vector
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;  // N lanes/wave
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; // K lanes/wave
        static_assert(KPerBlock % (Kv * Kw) == 0); // was K1*K2 (undefined names)
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);
        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/pipeline/fused_moe_pipeline_problem.hpp
View file @
199f7f71
...
...
@@ -33,17 +33,7 @@ struct FusedMoePipelineProblem
static
constexpr
index_t
kBlockSize
=
FusedMoeTileShape
::
NumWarps
*
get_warp_size
();
// attributes from traits
// static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
// static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
// static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
// static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// static constexpr auto BiasEnum = Traits::BiasEnum;
// static constexpr bool kStoreLSE = Traits::kStoreLSE;
// static constexpr bool kHasDropout = Traits::kHasDropout;
// static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
using
GateActivation
=
remove_cvref_t
<
typename
Traits
::
GateActivation_
>
;
//
using GateActivation = remove_cvref_t<typename Traits::GateActivation_>;
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_tile_shape.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/pipeline/fused_moe_tile_shape.hpp
View file @
199f7f71
...
...
@@ -37,12 +37,11 @@ M_a | A | | | | | | | | |
SILU x-----x +----------+
K_y = N_g = N_u dim
*/
template
<
typename
BlockTile_
,
// sequence<M_a, N_g, N_
u
, K_a, N_d
template
<
typename
BlockTile_
,
// sequence<M_a, N_g, N_
sub0
, K_a, N_d
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
,
bool
IsDLayoutRowMajor_
>
typename
Gemm1WarpTile_
>
struct
FusedMoeTileShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
...
...
@@ -58,25 +57,60 @@ struct FusedMoeTileShape
static
constexpr
index_t
kM_a
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kN_g
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN_u
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kK_a
=
BlockTile
::
at
(
number
<
3
>
{});
static
constexpr
index_t
kN_d
=
BlockTile
::
at
(
number
<
4
>
{});
static_assert
(
kN_g
==
kN_u
);
static
constexpr
index_t
kN_u
=
BlockTile
::
at
(
number
<
1
>
{});
// e.g. N_g = 256, n_sub_gu=128, then we split blockN of G/U into 2 parts to loopover
// this can help B matrix direct-to-register using too much vgpr issue
static
constexpr
index_t
kN_sub0
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kK_a
=
BlockTile
::
at
(
number
<
3
>
{});
static
constexpr
index_t
kN_d
=
BlockTile
::
at
(
number
<
4
>
{});
// static_assert(kN_g == kN_u);
static
constexpr
index_t
kK_y
=
kN_g
;
static
constexpr
index_t
kM_0
=
kM_a
;
static
constexpr
index_t
kN_0
=
kN_g
;
// note N will x2
static
constexpr
index_t
kK_0
=
kK_a
;
static
constexpr
index_t
kBlockNSub_0
=
kN_sub0
;
// allow partial
static
constexpr
index_t
kBlockM_0
=
kM_a
;
static
constexpr
index_t
kBlockN_0
=
kN_g
;
// note N will x2 in real pipeline for gemm-0
static
constexpr
index_t
kBlockK_0
=
kK_a
;
static
constexpr
index_t
kWarpM_0
=
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kWarpN_0
=
Gemm0WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kWarpK_0
=
Gemm0WarpTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockWarpsM_0
=
Gemm0BlockWarps
::
at
(
numner
<
0
>
{});
static
constexpr
index_t
kBlockWarpsN_0
=
Gemm0BlockWarps
::
at
(
numner
<
1
>
{});
static
constexpr
index_t
kBlockWarpsK_0
=
Gemm0BlockWarps
::
at
(
numner
<
2
>
{});
static
constexpr
index_t
kSubBlockM_0
=
kWarpM_0
*
kBlockWarpsM_0
;
static
constexpr
index_t
kSubBlockN_0
=
kWarpN_0
*
kBlockWarpsN_0
;
static
constexpr
index_t
kSubBlockK_0
=
kWarpK_0
*
kBlockWarpsK_0
;
static_assert
(
kBlockM_0
%
kSubBlockM_0
==
0
);
static_assert
(
kBlockN_0
%
kSubBlockN_0
==
0
);
static_assert
(
kBlockK_0
%
kSubBlockK_0
==
0
);
static
constexpr
index_t
kWarpRepeatM_0
=
kBlockM_0
/
kSubBlockM_0
;
static
constexpr
index_t
kWarpRepeatN_0
=
kBlockN_0
/
kSubBlockN_0
;
static
constexpr
index_t
kWarpRepeatK_0
=
kBlockK_0
/
kSubBlockK_0
;
static
constexpr
index_t
kM_1
=
kM_0
;
static
constexpr
index_t
kN_1
=
kN_d
;
static
constexpr
index_t
kK_1
=
kN_g
;
static
constexpr
index_t
kBlockKSub_1
=
kBlockNSub_0
;
static
constexpr
index_t
kBlockM_1
=
kM_a
;
static
constexpr
index_t
kBlockN_1
=
kN_d
;
static
constexpr
index_t
kBlockK_1
=
kN_g
;
static
constexpr
index_t
kWarpM_1
=
Gemm1WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kWarpN_1
=
Gemm1WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kWarpK_1
=
Gemm1WarpTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockWarpsM_1
=
Gemm1BlockWarps
::
at
(
numner
<
0
>
{});
static
constexpr
index_t
kBlockWarpsN_1
=
Gemm1BlockWarps
::
at
(
numner
<
1
>
{});
static
constexpr
index_t
kBlockWarpsK_1
=
Gemm1BlockWarps
::
at
(
numner
<
2
>
{});
static
constexpr
index_t
kSubBlockM_1
=
kWarpM_1
*
kBlockWarpsM_1
;
static
constexpr
index_t
kSubBlockN_1
=
kWarpN_1
*
kBlockWarpsN_1
;
static
constexpr
index_t
kSubBlockK_1
=
kWarpK_1
*
kBlockWarpsK_1
;
static_assert
(
kBlockM_1
%
kSubBlockM_1
==
0
);
static_assert
(
kBlockN_1
%
kSubBlockN_1
==
0
);
static_assert
(
kBlockK_1
%
kSubBlockK_1
==
0
);
static
constexpr
index_t
kWarpRepeatM_1
=
kBlockM_1
/
kSubBlockM_1
;
static
constexpr
index_t
kWarpRepeatN_1
=
kBlockN_1
/
kSubBlockN_1
;
static
constexpr
index_t
kWarpRepeatK_1
=
kBlockK_1
/
kSubBlockK_1
;
// d, rowmajor : hidden*dim, colmajor : dim*hidden (vLLM use this layout)
static
constexpr
bool
IsDLayoutRowMajor
=
IsDLayoutRowMajor_
;
using
DLayout
=
std
::
conditional_t
<
IsDLayoutRowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
//
static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
//
using DLayout = std::conditional_t<IsDLayoutRowMajor,
//
ck_tile::tensor_layout::gemm::RowMajor,
//
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
namespace
ck_tile
{
// Compile-time configuration knobs for the fused-MoE kernel. Each member
// simply re-exports its template parameter as a static constexpr constant so
// downstream code can read `Traits::X` uniformly.
template <bool DownPreShuffled_ = false,
          index_t kBlockPerCu_  = -1 /* overwrite occupancy if not -1 */,
          index_t OAtomic_      = 0,
          FusedMoeWeightPermuteEnum WeightPermute_ =
              FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
struct FusedMoeTraits
{
    // presumably: whether the down weight is pre-shuffled offline — confirm at call sites
    static constexpr bool DownPreShuffled = DownPreShuffled_;
    // occupancy override in blocks per CU; -1 means "do not override"
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
    // layout permutation applied to the weight matrices (see FusedMoeWeightPermuteEnum)
    static constexpr FusedMoeWeightPermuteEnum WeightPermute = WeightPermute_;
    static constexpr index_t OAtomic = OAtomic_; // 0-pack fp16/bf16 atomic, 1-fp32 atomic
};
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck_tile
{
enum
class
FusedMoeWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv
=
2
,
// 0,1,3,4,2,5
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
no_permute
=
999
,
};
}
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
199f7f71
...
...
@@ -616,11 +616,51 @@ struct buffer_store_if<1>
}
};
// Wait until at most `cnt` vector-memory operations remain outstanding
// (s_waitcnt vmcnt); cnt == 0 drains them all. The "memory" clobber stops the
// compiler from moving memory accesses across the fence.
// NOTE(review): the "n" (immediate) asm constraint requires `cnt` to fold to
// a compile-time constant after inlining.
CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Compile-time-count overload of buffer_load_fence_raw: the wait count is a
// template parameter, so the immediate constraint is always satisfiable.
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence_raw(number<cnt>)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#if 0
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
// const index_t origin_cnt = 0x0f70;
__builtin_amdgcn_s_waitcnt(0x0f70 | cnt);
}
#endif
/// Builtin-based vmcnt fence: encodes `cnt` into the s_waitcnt simm16
/// immediate (0x0f70 leaves expcnt/lgkmcnt at "no wait").
/// Fix: vmcnt bits [5:4] belong in simm16 bits [15:14] (as the comment below
/// states), i.e. a left shift of 10; the original shifted by 14, pushing them
/// past the 16-bit immediate entirely.
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence(number<cnt>)
{
    /*
        simm16, simm16[3:0] -> bits[3:0], simm16[15:14] -> bits[5:4]
    */
    static_assert(cnt < 64); // vmcnt is a 6-bit field
    constexpr index_t low = cnt & 0xf;
    constexpr index_t hi  = (cnt & 0x30) << 10; // was << 14 (out of simm16 range)
    constexpr index_t c   = 0x0f70 | low | hi;
    __builtin_amdgcn_s_waitcnt(c);
}
// Workgroup-wide execution barrier via the s_barrier builtin.
CK_TILE_DEVICE void wave_barrier() { __builtin_amdgcn_s_barrier(); }
// Same s_waitcnt vmcnt fence as buffer_load_fence_raw, named for use after
// async (direct global->LDS) loads.
// NOTE(review): the "n" constraint requires `cnt` to be a constant after inlining.
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Compile-time-count async-load fence: forwards to buffer_load_fence, which
// builds the s_waitcnt immediate at compile time.
template <index_t cnt>
CK_TILE_DEVICE auto async_load_fence(number<cnt>)
{
    buffer_load_fence(number<cnt>{});
}
// clang-format off
namespace
impl
{
...
...
@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
}
// clang-format on
template
<
typename
...
T
>
CK_TILE_DEVICE
void
buffer_load_fence
(
index_t
cnt
=
0
,
T
&
...
o
)
CK_TILE_DEVICE
void
buffer_load_fence
_raw
(
index_t
cnt
=
0
,
T
&
...
o
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
impl
::
insert_dummy_dep
(
o
...);
}
CK_TILE_DEVICE
void
buffer_store_fence
(
index_t
cnt
=
0
)
CK_TILE_DEVICE
void
buffer_store_fence
_raw
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
...
...
@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_load_lds
(
int32x4_t
rsrc
,
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
,
index_t
size
,
index_t
voffset
,
index_t
soffset
,
index_t
offset
,
index_t
aux
)
__asm
(
"llvm.amdgcn.raw.buffer.load.lds"
);
template
<
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
async_buffer_load_dword_v
(
void
*
smem
,
int32x4_t
rsrc
,
...
...
@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
:
"memory"
);
}
#if 0
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#endif
// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
...
...
@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant
<
pre_nop
>
{});
}
/// Issue one direct global->LDS (buffer_load ... lds) dword copy.
/// When oob_conditional_check is set and `flag` is false, the voffset is
/// redirected to num_records (buffer resource word 2) so the access lands
/// out-of-bounds and the hardware drops it.
/// Fix: the original initialized `v_offset` from itself
/// (`flag ? v_offset : ...`), reading an indeterminate value; the valid path
/// must use src_thread_addr_offset (otherwise unused in that branch).
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                          int32x4_t src_wave_buffer_resource,
                                          index_t src_thread_addr_offset,
                                          index_t src_wave_addr_offset,
                                          index_t src_immediate_addr_offset = 0,
                                          index_t flag                      = 0,
                                          bool_constant<oob_conditional_check> = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
    if constexpr(oob_conditional_check)
    {
        // was: index_t v_offset = flag ? v_offset : ... (self-referential init)
        const index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        // reinterpret_cast<CK_TILE_LDS_ADDR
                                        // uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
                                        sizeof(uint32_t),
                                        v_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else
    {
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        // reinterpret_cast<CK_TILE_LDS_ADDR
                                        // uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
                                        sizeof(uint32_t),
                                        src_thread_addr_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
}
template
<
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
...
...
@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
}
// This version support buffer resource as input arg
// This version support buffer resource as input arg.
// Converts an element offset to a byte offset and forwards to
// amd_async_buffer_load; `is_valid_element` becomes the OOB flag.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    const index_t byte_offset = src_thread_element_offset * sizeof(T);
    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           byte_offset,
                                           0,
                                           0,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
...
...
@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_load_lds
(
int32x4_t
rsrc
,
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
,
index_t
size
,
index_t
voffset
,
index_t
soffset
,
index_t
offset
,
index_t
aux
)
__asm
(
"llvm.amdgcn.raw.buffer.load.lds"
);
template
<
typename
T
,
index_t
NumElemsPerThread
>
CK_TILE_DEVICE
void
amd_direct_load_global_to_lds
(
const
T
*
global_base_ptr
,
const
index_t
global_offset
,
...
...
include/ck_tile/core/config.hpp
View file @
199f7f71
...
...
@@ -43,6 +43,20 @@
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
// Fix: the guard was `__HIPCC_` (one trailing underscore), so the attributes
// were never enabled when compiling with hipcc; the predefined macro is
// `__HIPCC__`.
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
...
...
include/ck_tile/core/tensor/buffer_view.hpp
View file @
199f7f71
...
...
@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global,
dst
,
cached_buf_res_
,
i
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
// i is offset of T, not X. i should be aligned to X
// i is offset of T, not X. i should be aligned to X.
// Async (direct global->LDS) fetch of one X-sized vector; X must be a vector
// whose scalar type matches T's scalar type (enforced by the enable_if).
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                           typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
              bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
                                        index_t i,
                                        bool is_valid_element,
                                        bool_constant<oob_conditional_check> = {}) const
{
    // X is vector of T
    constexpr index_t t_scalars = vector_traits<remove_cvref_t<T>>::vector_size;
    constexpr index_t x_scalars = vector_traits<remove_cvref_t<X>>::vector_size;
    static_assert(x_scalars % t_scalars == 0, "wrong! X should contain multiple T");

    amd_async_buffer_load_with_oob<remove_cvref_t<T>, x_scalars / t_scalars, Coherence>(
        smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
}
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
bool
pre_nop
=
false
,
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @
199f7f71
...
...
@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect the
// dependency if using multiple smem window (multiple buffer)
// Async global->LDS tile copy with OOB checking enabled by default.
// Thin forwarder: the tile window performs the actual copy into `lds_tile`.
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
                                    const tile_window_with_static_distribution<
                                        BottomTensorView_,
                                        WindowLengths_,
                                        TileDistribution_,
                                        NumCoord>& tile_window,
                                    bool_constant<oob_conditional_check> = {})
{
    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
template
<
typename
LdsTileWindow_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
...
...
@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
// Wait until at most `cnt` outstanding vector-memory operations (including
// async global->LDS loads) remain; cnt == 0 drains all of them.
// NOTE(review): the "n" (immediate) asm constraint requires `cnt` to fold to
// a compile-time constant after inlining.
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template
<
typename
WindowLengths
>
CK_TILE_DEVICE
auto
load_tile
(
const
null_tile_window
<
WindowLengths
>&
)
{
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
199f7f71
...
...
@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
}
namespace
detail
{
// check if 2 static_distributed_tensor has same data type and size of element
// but only difference in distribution
// Primary template: arbitrary type pairs are not "similar".
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
    static constexpr bool value = false;
};

// Specialization for two static_distributed_tensor instantiations:
// "similar" means identical DataType and identical per-thread buffer size;
// the distributions themselves may differ.
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
                                      static_distributed_tensor<TypeY, DistY>>
{
    using Tx = static_distributed_tensor<TypeX, DistX>;
    using Ty = static_distributed_tensor<TypeY, DistY>;
    static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
                                  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};

// Convenience variable template. (The "similiar" spelling is the public name
// and is kept as-is.)
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
    is_similiar_distributed_tensor<X, Y>::value;
}
// namespace detail
}
// namespace ck_tile
include/ck_tile/core/tensor/tensor_view.hpp
View file @
199f7f71
...
...
@@ -104,6 +104,23 @@ struct tensor_view
bool_constant
<
pre_nop
>
{});
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements
(
CK_TILE_LDS_ADDR
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
)
const
{
return
buf_
.
template
async_get
<
X
>(
smem
,
coord
.
get_offset
(),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
oob_conditional_check
>
{});
}
template
<
typename
X
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
199f7f71
...
...
@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution
});
}
    // Asynchronously load this window's tile into an LDS tile window: every
    // per-thread vector access is issued via async_get_vectorized_elements,
    // which writes directly to LDS (presumably lowered to direct-to-LDS
    // buffer loads — confirm against the buffer_view implementation).
    //
    // @param lds_tile              destination LDS tile window; its descriptor
    //                              must be 3D as (issues, warps, lanes)
    // @param oob_conditional_check compile-time flag: guard each access
    //                              against out-of-bounds coordinates
    template <typename LdsTileWindow_, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3);

        // TODO: hard coded
        // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
        // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
        // check?)
        // Offset of element (0,0,0) — the base of the LDS buffer.
        constexpr index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{}));

        // LDS stride between consecutive warps (distance from warp 0 to warp 1).
        constexpr index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) -
            size_per_buf;

        // LDS stride between consecutive issues (distance from issue 0 to issue 1);
        // applied manually to `smem` after each access below.
        constexpr index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) -
            size_per_buf;

        // starting LDS offset for this thread's warp
        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

        using Traits   = load_store_traits;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        // TODO: we force CK_TILE_LDS_ADDR
        CK_TILE_LDS_ADDR LdsDataType* smem =
            lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            // TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                // flat access index across all coords, used to query the
                // space-filling curve for the next step
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    // Y-space step for this access, per the space-filling curve.
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    // prepend zero P-dimension deltas: only the Y coordinates move
                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

                    smem += size_per_issue; // Note we manually increase the per-issue offset
                }
            });
        });
    }
template
<
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
bool_constant
<
oob_conditional_check
>
=
{})
const
...
...
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
199f7f71
...
...
@@ -39,7 +39,7 @@ struct Default2DEpilogue
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
buffer_store_fence
_raw
();
}
else
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
View file @
199f7f71
...
...
@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
buffer_load_fence
_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
...
...
@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...
...
@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
...
...
@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
async_load_fence
_raw
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment