Commit 1ba8a08f authored by carlushuang
Browse files

update tmp work

parent bf214665
...@@ -252,7 +252,7 @@ struct FusedMoeKernel ...@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; // __shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane( ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));
ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x;); ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x;);
...@@ -436,8 +436,7 @@ struct FusedMoeKernel ...@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window, u_gtile_window,
d_gtile_window, d_gtile_window,
o_gtile_window, o_gtile_window,
scale, scale);
smem_ptr);
tile_id += gridDim.x; tile_id += gridDim.x;
} }
......
...@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col // this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex() CK_TILE_HOST_DEVICE static auto GetAIndex()
{ {
constexpr auto a_dist = Policy::template MakeAGlobalTileDistribution<Problem>(); constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index(); const auto a_coord = a_dist.calculate_index();
return a_coord; return a_coord;
} }
...@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView& o_gtile_window_tmp, OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr, // const void * sorted_weight_ptr,
ScaleDataType scale, ScaleDataType scale,
void* smem_ptr, CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size, index_t dim_size,
index_t hidden_size) index_t hidden_size)
{ {
...@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(), a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(), a_gtile_window_tmp.get_window_origin(),
Policy::template MakeAGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_gtile_window = auto g_gtile_window =
make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
g_gtile_window_tmp.get_window_lengths(), g_gtile_window_tmp.get_window_lengths(),
g_gtile_window_tmp.get_window_origin(), g_gtile_window_tmp.get_window_origin(),
Policy::template MakeGGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
auto u_gtile_window = auto u_gtile_window =
make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
u_gtile_window_tmp.get_window_lengths(), u_gtile_window_tmp.get_window_lengths(),
u_gtile_window_tmp.get_window_origin(), u_gtile_window_tmp.get_window_origin(),
Policy::template MakeUGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_U<Problem>());
auto d_gtile_window = auto d_gtile_window =
make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
d_gtile_window_tmp.get_window_lengths(), d_gtile_window_tmp.get_window_lengths(),
d_gtile_window_tmp.get_window_origin(), d_gtile_window_tmp.get_window_origin(),
Policy::template MakeDGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_D<Problem>());
auto o_gtile_window = auto o_gtile_window =
make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
...@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset; auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
make_tile_window(make_tensor_view<address_space_enum::lds>( auto smem_0_window = make_tile_window(
a_smem_ptr, Policy::template MakeALdsStoreBlockDescriptor<Problem>()), make_tensor_view<address_space_enum::lds>(
Policy::template MakeALdsStoreBlockDescriptor<Problem>().get_lengths(), smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
{0, 0}); Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
{0, 0});
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), a_gtile_window); async_load_tile(k_lds_store(LdsSeq.at(number<0>{})));
for(index_t i_0 = 0; i_0 < loops_0; i_0++) {} for(index_t i_0 = 0; i_0 < loops_0; i_0++) {}
} }
...@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
} }
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?) // otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleared, return it // Note: here occ are all cleared, return it
return o_acc; return o_acc;
...@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q); // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access()); async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc, gemm_0(s_acc,
...@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k0_loops <= 2) if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
async_load_fence(); async_load_fence_raw();
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
......
...@@ -14,7 +14,6 @@ namespace ck_tile { ...@@ -14,7 +14,6 @@ namespace ck_tile {
struct FusedMoePipelinePolicy struct FusedMoePipelinePolicy
{ {
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{ {
// TODO: // TODO:
...@@ -22,7 +21,7 @@ struct FusedMoePipelinePolicy ...@@ -22,7 +21,7 @@ struct FusedMoePipelinePolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{ {
// using async // using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords(); static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
...@@ -32,54 +31,27 @@ struct FusedMoePipelinePolicy ...@@ -32,54 +31,27 @@ struct FusedMoePipelinePolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentG() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{ {
static constexpr index_t copy_bytes = [&]() { static constexpr index_t copy_bytes = [&]() { return 16; }();
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType); static constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentU() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_U()
{ {
static constexpr index_t copy_bytes = [&]() { static constexpr index_t copy_bytes = [&]() { return 16; }();
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::UDataType); static constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentD() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{ {
static constexpr index_t copy_bytes = [&]() { static constexpr index_t copy_bytes = [&]() { return 16; }();
if constexpr(Problem::Traits::DownPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType); static constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
...@@ -93,29 +65,11 @@ struct FusedMoePipelinePolicy ...@@ -93,29 +65,11 @@ struct FusedMoePipelinePolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackA() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{ {
return GetSmemKPack<typename Problem::ADataType>(); return GetSmemKPack<typename Problem::ADataType>();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackG()
{
return GetSmemKPack<typename Problem::GDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackU()
{
return GetSmemKPack<typename Problem::UDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackD()
{
return GetSmemKPack<typename Problem::DDataType>();
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment> template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{ {
...@@ -206,102 +160,58 @@ struct FusedMoePipelinePolicy ...@@ -206,102 +160,58 @@ struct FusedMoePipelinePolicy
// Caution: this will require global memory pre-shuffled to follow the mfma layout // Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channel while skip LDS // to maximize the L1/L2 channel while skip LDS
/* template <index_t NPerBlock,
index_t KPerBlock,
(b) n0 n1 n2 k0 k1 k2 index_t WavesPerBlock_N,
index_t WavesPerBlock_K,
klanes typename WarpGemm,
| index_t Alignment,
nr 4 kr 4 16 8 FusedMoePermuteStyle PermuteStyle = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv>
(b) n0 n1 k0 k1 n2 k2 -> kthreads CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
| |
V V
waves nlanes
klanes
|
nr kr 4 4 16 8
(b) n0 k0 n1 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
*/
template <typename BlockTile, typename BlockWarps, typename WarpGemm, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled_NxK()
{ {
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0); static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
static_assert(BlockWarps{}.at(number<0>{}) == 1 && BlockWarps{}.at(number<2>{}) == 1);
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
constexpr index_t NPerBlock = BlockTile{}.at(number<1>{});
constexpr index_t KPerBlock = BlockTile{}.at(number<2>{});
constexpr index_t K2 = Alignment; if constexpr(PermuteStyle == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; {
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t N1 = NumWarps; constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
static_assert(NPerBlock % (N1 * N2) == 0); constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2); static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t N0 = NPerBlock / (N1 * N2); constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return make_static_tile_distribution( constexpr index_t Nr_p = WavesPerBlock_N;
tile_distribution_encoding<sequence<1>, constexpr index_t Kr_p = WavesPerBlock_K;
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>, constexpr index_t Nr_y = Nr / Nr_p;
tuple<sequence<1, 2>, sequence<2>>, constexpr index_t Kr_y = Kr / Kr_p;
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>, // 0
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>, // major 1 2 3
tuple<sequence<1, 2>, sequence<2>>, // minor 0 1 0 1 0 1 2
tuple<sequence<1, 0>, sequence<1>>, tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
sequence<1, 2>,
sequence<0, 2>>{}); // Nr_p, Kr_p Kw Nw
} tuple<sequence<1, 2>, sequence<3, 3>>,
else tuple<sequence<1, 1>, sequence<0, 1>>,
{
constexpr index_t K_lan = K_rem; // Nr_y Kr_y Kv
constexpr index_t M_lan = get_warp_size() / K_lan; sequence<1, 2, 3>,
constexpr index_t M_wav = NumWarps; sequence<0, 0, 2>>{});
static_assert(MPerBlock % (M_lan * M_wav) == 0, // clang-format on
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeAGlobalTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{ {
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a; constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>(); constexpr index_t Alignment = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kMPerBlock, return MakeGlobalTileDistribution_SimpleMxK_Async<kMPerBlock,
kKPerBlock, kKPerBlock,
NumWarps, NumWarps,
...@@ -309,42 +219,75 @@ struct FusedMoePipelinePolicy ...@@ -309,42 +219,75 @@ struct FusedMoePipelinePolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGGlobalTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{ {
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g; constexpr auto PermuteStype = Problem::Traits::PermuteStyle;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; if constexpr(PermuteStype == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; {
constexpr index_t Alignment = GetAlignmentG<Problem>(); constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock, constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
kKPerBlock, constexpr index_t WavesPerBlock_N = Problem::Gemm0BlockWarps {}
NumWarps, ::at(number<1>{});
Alignment>(); constexpr index_t WavesPerBlock_K = Problem::Gemm0BlockWarps {}
::at(number<2>{});
using WarpGemm = remove_cvref_t<GetWarpGemm0<Problem>()>;
constexpr index_t Alignment = GetAlignment_G<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStype>();
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeUGlobalTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U()
{ {
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u; constexpr auto PermuteStype = Problem::Traits::PermuteStyle;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; if constexpr(PermuteStype == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; {
constexpr index_t Alignment = GetAlignmentU<Problem>(); constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock, constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
kKPerBlock, constexpr index_t WavesPerBlock_N = Problem::Gemm0BlockWarps {}
NumWarps, ::at(number<1>{});
Alignment>(); constexpr index_t WavesPerBlock_K = Problem::Gemm0BlockWarps {}
::at(number<2>{});
using WarpGemm = remove_cvref_t<GetWarpGemm0<Problem>()>;
constexpr index_t Alignment = GetAlignment_U<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStype>();
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeDGlobalTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{ {
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d; constexpr auto PermuteStype = Problem::Traits::PermuteStyle;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y; if constexpr(PermuteStype == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; {
constexpr index_t Alignment = GetAlignmentD<Problem>(); constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock, constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
kKPerBlock, constexpr index_t WavesPerBlock_N = Problem::Gemm1BlockWarps {}
NumWarps, ::at(number<1>{});
Alignment>(); constexpr index_t WavesPerBlock_K = Problem::Gemm1BlockWarps {}
::at(number<2>{});
using WarpGemm = remove_cvref_t<GetWarpGemm1<Problem>()>;
constexpr index_t Alignment = GetAlignment_D<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStype>();
}
} }
template <index_t MPerBlock, template <index_t MPerBlock,
...@@ -359,10 +302,8 @@ struct FusedMoePipelinePolicy ...@@ -359,10 +302,8 @@ struct FusedMoePipelinePolicy
constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize; constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds constexpr index_t KVector = Alignment; // this is for global load
constexpr index_t KVector = constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
Alignment; // GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= KPerBlock && warpSize * KVector % KPerBlock == 0); static_assert(warpSize * KVector >= KPerBlock && warpSize * KVector % KPerBlock == 0);
constexpr index_t LanesPerK = KPerBlock / KVector; // within a wave constexpr index_t LanesPerK = KPerBlock / KVector; // within a wave
...@@ -402,77 +343,188 @@ struct FusedMoePipelinePolicy ...@@ -402,77 +343,188 @@ struct FusedMoePipelinePolicy
return lds_block_desc; return lds_block_desc;
} }
template <index_t MPerBlock, template <typename Problem>
index_t KPerBlock, CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor_A()
index_t NumWarps,
index_t KPack,
index_t Alignement,
index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeSmemStoreBlockDescriptor_SimpleMxK_Async(number<IBuf> = number<0>{})
{ {
constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize; // A async->LDS
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds static_assert(kKPerBlock % kVector == 0);
// constexpr index_t Alignement = GetAlignmentK<Problem>(); // this is for global load constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K
constexpr index_t kPad = if constexpr(LanesPerK > warpSize)
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed {
// need multiple waves to load K
static_assert(warpSize * Alignement >= KPerBlock && warpSize * Alignement % KPerBlock == 0); static_assert(LanesPerK % warpSize == 0);
constexpr index_t LanesPerK = constexpr index_t wavesPerK = LanesPerK / warpSize;
KPerBlock / Alignement; // how many lane (within a wave) to load K if constexpr(wavesPerK > NumWarps)
constexpr index_t LaneGroups = {
warpSize / // TODO: need multiple issues along K to load all data
LanesPerK; // how many groups (within a wave), they may load different N, but same K }
constexpr index_t NumIssues = MPerBlock / (LaneGroups * NumWarps); else
static_assert(NumIssues == MPerBlock * KPerBlock / (BlockSize * Alignement)); {
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( constexpr index_t NumIssues = kMPerBlock / wavesPerM;
make_tuple(number<NumIssues>{}, // n0 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
number<LaneGroups>{}, // n1 make_tuple(number<NumIssues>{}, // m0
number<NumWarps>{}, // n2 number<wavesPerM>{}, // m1
number<LanesPerK>{}, // k0 number<wavesPerK>{}, // k0
number<Alignement>{}), // k1 number<warpSize>{}, // k1
make_tuple(number<NumWarps*(warpSize * Alignement + kPad)>{}, number<KVector>{}), // k2
number<KPerBlock>{}, make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<warpSize * Alignement + kPad>{}, number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<Alignement>{}, number<warpSize * KVector + kPad>{}, // k0
number<1>{}), number<KVector>{}, // k1
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{}, number<1>{}), // k2
number<Alignement>{}, number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{}); number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1) lds_block_desc_0,
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( make_tuple(
k_lds_block_desc_0, make_pass_through_transform(number<NumIssues>{}),
make_tuple(make_pass_through_transform(number<NumIssues>{}), make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_pass_through_transform(number<NumWarps>{}), make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_merge_transform(make_tuple( make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
number<LaneGroups>{}, number<LanesPerK>{}, number<Alignement>{}))), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); return lds_block_desc_issues_warps_lanes;
}
return k_lds_block_desc_issues_warps_lanes; }
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemLoadTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeSmemLoadTileDistribution_A()
{ {
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a; // A async->LDS
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; // Note that, this descriptor is only to construct the layout inside LDS
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; // in real Gemm pipeline, ds_read may not follow this pattern
constexpr index_t Alignment = GetAlignmentA<Problem>(); // (may follow that in tile_distribution)
constexpr index_t KPack = GetSmemKPackA<Problem>(); // below code is almost the same as SmemStore dist, with difference:
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA; // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kMPerBlock, constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
kKPerBlock, constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
NumWarps, constexpr index_t kPad = KPack; // pad between warps
Alignment,
KPack, static_assert(kKPerBlock % kVector == 0);
NumPrefetch>(); constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K
if constexpr(LanesPerK > warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = kMPerBlock / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemStoreTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeASmemStoreTileDistribution()
...@@ -480,8 +532,8 @@ struct FusedMoePipelinePolicy ...@@ -480,8 +532,8 @@ struct FusedMoePipelinePolicy
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a; constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>(); constexpr index_t Alignment = GetAlignment_A<Problem>();
constexpr index_t KPack = GetSmemKPackA<Problem>(); constexpr index_t KPack = GetSmemKPack_A<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA; constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA;
return MakeSmemStoreBlockDescriptor_SimpleMxK_Async<kMperBlock, return MakeSmemStoreBlockDescriptor_SimpleMxK_Async<kMperBlock,
...@@ -492,13 +544,14 @@ struct FusedMoePipelinePolicy ...@@ -492,13 +544,14 @@ struct FusedMoePipelinePolicy
Alignment>(); Alignment>();
} }
#if 0
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution()
{ {
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g; constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>(); constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>(); constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG; constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
...@@ -515,7 +568,7 @@ struct FusedMoePipelinePolicy ...@@ -515,7 +568,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g; constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>(); constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>(); constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG; constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
...@@ -533,7 +586,7 @@ struct FusedMoePipelinePolicy ...@@ -533,7 +586,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u; constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentU<Problem>(); constexpr index_t Alignment = GetAlignment_U<Problem>();
constexpr index_t KPack = GetSmemKPackU<Problem>(); constexpr index_t KPack = GetSmemKPackU<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU; constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU;
...@@ -551,7 +604,7 @@ struct FusedMoePipelinePolicy ...@@ -551,7 +604,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d; constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y; constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps; constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentD<Problem>(); constexpr index_t Alignment = GetAlignment_D<Problem>();
constexpr index_t KPack = GetSmemKPackD<Problem>(); constexpr index_t KPack = GetSmemKPackD<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD; constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD;
...@@ -562,7 +615,32 @@ struct FusedMoePipelinePolicy ...@@ -562,7 +615,32 @@ struct FusedMoePipelinePolicy
KPack, KPack,
NumPrefetch>(); NumPrefetch>();
} }
#endif
// Returns a default-constructed WarpGemmMfmaDispatcher specialization built
// from Problem's ADataType / GDataType / AccDataType and the first three
// entries of Problem::FusedMoeTileShape::Gemm0WarpTile.
// NOTE(review): the three warp-tile entries are presumably the M/N/K extents
// of one warp-level MFMA tile -- confirm against FusedMoeTileShape.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
    using WarpTile = typename Problem::FusedMoeTileShape::Gemm0WarpTile;

    // Pull the warp-tile entries out once so the dispatcher argument list stays short.
    constexpr auto kTile0 = WarpTile::at(number<0>{});
    constexpr auto kTile1 = WarpTile::at(number<1>{});
    constexpr auto kTile2 = WarpTile::at(number<2>{});

    // Last dispatcher argument; the original call site labels it TransposeC.
    constexpr bool kTransposeC = true;

    using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                            typename Problem::GDataType,
                                            typename Problem::AccDataType,
                                            kTile0,
                                            kTile1,
                                            kTile2,
                                            kTransposeC>;
    return WarpGemm{};
}
// Returns a default-constructed WarpGemmMfmaDispatcher specialization built
// from Problem's YDataType / DDataType / AccDataType and the first three
// entries of Problem::FusedMoeTileShape::Gemm1WarpTile.
// NOTE(review): the three warp-tile entries are presumably the M/N/K extents
// of one warp-level MFMA tile -- confirm against FusedMoeTileShape.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
    using WarpTile = typename Problem::FusedMoeTileShape::Gemm1WarpTile;

    // Pull the warp-tile entries out once so the dispatcher argument list stays short.
    constexpr auto kTile0 = WarpTile::at(number<0>{});
    constexpr auto kTile1 = WarpTile::at(number<1>{});
    constexpr auto kTile2 = WarpTile::at(number<2>{});

    // Last dispatcher argument; the original call site labels it TransposeC.
    constexpr bool kTransposeC = true;

    using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::YDataType,
                                            typename Problem::DDataType,
                                            typename Problem::AccDataType,
                                            kTile0,
                                            kTile1,
                                            kTile2,
                                            kTransposeC>;
    return WarpGemm{};
}
#if 0
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemm0() CK_TILE_HOST_DEVICE static constexpr auto GetGemm0()
{ {
...@@ -628,5 +706,6 @@ struct FusedMoePipelinePolicy ...@@ -628,5 +706,6 @@ struct FusedMoePipelinePolicy
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
} }
#endif
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -12,21 +12,22 @@ ...@@ -12,21 +12,22 @@
namespace ck_tile { namespace ck_tile {
template <bool GateUpPreShuffled_ = false, enum class FusedMoePermuteStyle
bool DownPreShuffled_ = false, {
index_t NumPrefetchA_ = 2, // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
index_t NumPrefetchG_ = 2, // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
index_t NumPrefetchU_ = 2, permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
index_t NumPrefetchD_ = 2, permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> no_permute = 999,
};
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits struct FusedMoeTraits
{ {
static constexpr bool GateUpPreShuffled = GateUpPreShuffled_; static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr bool DownPreShuffled = DownPreShuffled_; static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t NumPrefetchA = NumPrefetchA_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr index_t NumPrefetchG = NumPrefetchG_;
static constexpr index_t NumPrefetchU = NumPrefetchU_;
static constexpr index_t NumPrefetchD = NumPrefetchD_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment