Commit cf646183 authored by carlushuang

compile OK

parent 70fa98ad
@@ -5,7 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/layernorm2d.hpp"
+#include "ck_tile/ops/fused_moe.hpp"
 #include <string>
 // this is only a convenient structure for creating an example
@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
 struct FusedMoeGemmTypeConfig;
 template <typename ST, typename SW, typename SQ, typename KW>
-struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
+struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
 {
     using ADataType = ck_tile::bf16_t;
     using GDataType = ck_tile::bf16_t;
@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
 };
 template <typename ST, typename SW, typename SQ, typename KW>
-struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
+struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
 {
     using ADataType = ck_tile::int8_t;
     using GDataType = ck_tile::int8_t;
@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
 };
 // runtime args
-struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
+struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
 {
 };
...
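For orientation, downstream code consumes the config above purely through its nested aliases; a minimal sketch (scale types ST/SW/SQ/KW filled with float here to match the shipped instance; requires <type_traits>):

    using Cfg = FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
                                       float, float, float, float>;
    static_assert(std::is_same_v<typename Cfg::ADataType, ck_tile::bf16_t>); // activation type
    static_assert(std::is_same_v<typename Cfg::GDataType, ck_tile::bf16_t>); // gate/up weight type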
@@ -3,33 +3,25 @@
 #include <ck_tile/core.hpp>
 #include "fused_moegemm.hpp"
+#include "fused_moegemm_api_traits.hpp"
 // Note: this internal API only declare, not define here, otherwise will block `make -j`
 template <typename Traits_>
 float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
+template <ck_tile::index_t... Is>
+using S = ck_tile::sequence<Is...>;
 float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
 {
-    template <ck_tile::index_t... Is>
-    using S = ck_tile::sequence<Is...>;
+    // clang-format off
     float r = -1;
     if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
-       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && block_m == 32 &&
-       gate_only == 1)
+       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
     {
-        using t_ = fmoe_<ck_tile::bf16_t,
-                         ck_tile::bf16_t,
-                         ck_tile::bf16_t,
-                         float,
-                         float,
-                         float,
-                         float,
-                         S<32, 512, 128, 128>,
-                         S<4, 1, 1>,
-                         S<32, 32, 16>,
-                         1,
-                         0>;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
         fused_moegemm_<t_>(s, a);
     }
+    // clang-format on
     return r;
 }
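A minimal host-side sketch of driving this entry point (field assignments follow the string checks above; buffer setup is elided). Note that as committed the matched branch discards the value of fused_moegemm_, so the function always returns -1; a caller wanting the measured time would write r = fused_moegemm_<t_>(s, a); instead.

    fused_moegemm_traits t{};
    t.prec_i = "bf16"; t.prec_w = "bf16"; t.prec_o = "bf16";
    t.prec_st = "fp32"; t.prec_sw = "fp32"; t.prec_sq = "fp32"; t.prec_kw = "fp32";
    t.block_m   = 32;
    t.gate_only = 1;
    fused_moegemm_args a{/* device pointers and sizes, as filled in the example's main() */};
    float ms = fused_moegemm(t, a, ck_tile::stream_config{nullptr, true});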
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
 #include "fused_moegemm_api_traits.hpp"
 #include "ck_tile/ops/fused_moe.hpp"
+#include <iostream>
+template <ck_tile::index_t... Is>
+using S = ck_tile::sequence<Is...>;
+// do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j
 template <typename Ts_>
 float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
 {
     using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
     using f_shape  = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
                                                 typename Ts_::WarpPerBlock_0,
-                                                typename Ts::WarpTile_0,
+                                                typename Ts_::WarpTile_0,
                                                 typename Ts_::BlockTile_1,
                                                 typename Ts_::WarpPerBlock_0,
-                                                typename Ts::WarpTile_0>;
+                                                typename Ts_::WarpTile_0>;
-    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
+    using f_problem =
+        ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
                                              typename Ts_::GDataType,
                                              typename Ts_::DDataType,
                                              typename Ts_::AccDataType,
@@ -25,9 +33,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
                                              typename Ts_::YSmoothScaleDataType,
                                              typename Ts_::TopkWeightDataType,
                                              typename Ts_::IndexDataType,
-                                             ck_tile::Gelu, // TODO: hardcoded
+                                             ck_tile::element_wise::Gelu, // TODO: hardcoded
                                              f_shape,
-                                             f_traits>
+                                             f_traits>;
     using f_pipeline    = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
     using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
 #include <ck_tile/core.hpp>
 // this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal
 {
     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
-    using ADataType            = remove_cvref_t<typename TypeConfig::ADataType>;
-    using GDataType            = remove_cvref_t<typename TypeConfig::GDataType>;
-    using DDataType            = remove_cvref_t<typename TypeConfig::DDataType>;
-    using AccDataType          = remove_cvref_t<typename TypeConfig::AccDataType>;
-    using ODataType            = remove_cvref_t<typename TypeConfig::ODataType>;
-    using AScaleDataType       = remove_cvref_t<typename TypeConfig::AScaleDataType>;
-    using GScaleDataType       = remove_cvref_t<typename TypeConfig::GScaleDataType>;
-    using DScaleDataType       = remove_cvref_t<typename TypeConfig::DScaleDataType>;
-    using YSmoothScaleDataType = remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
-    using TopkWeightDataType   = remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
-    using IndexDataType        = remove_cvref_t<typename TypeConfig::IndexDataType>;
+    using ADataType            = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
+    using GDataType            = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
+    using DDataType            = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
+    using AccDataType          = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
+    using ODataType            = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
+    using AScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
+    using GScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
+    using DScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
+    using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
+    using TopkWeightDataType   = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
+    using IndexDataType        = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
-    static constexpr index_t BT_ = BlockTIle_::at(number<0>{}); // block token
-    static constexpr index_t BI_ = BlockTIle_::at(number<1>{}); // block intermediate
-    static constexpr index_t BH_ = BlockTIle_::at(number<2>{}); // block hidden
-    static constexpr index_t BD_ = BlockTIle_::at(number<3>{}); // block down
+    static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
+    static constexpr ck_tile::index_t BI_ =
+        BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
+    static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
+    static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
     using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
-    using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_>;
-    using WarpTile_0     = remove_cvref_t<WarpTile_>;
-    ;
+    using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
+    using WarpTile_0     = ck_tile::remove_cvref_t<WarpTile_>;
     using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
-    using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_>;
-    using WarpTile_1     = remove_cvref_t<WarpTile_>;
+    using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
+    using WarpTile_1     = ck_tile::remove_cvref_t<WarpTile_>;
     static constexpr ck_tile::index_t GateOnly   = GateOnly_;
     static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
...
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#include <ck_tile/core.hpp>
+#include "fused_moegemm.hpp"
+#include "fused_moegemm_api_traits.hpp"
+#include "fused_moegemm_api_internal.hpp"
+// clang-format off
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+// clang-format on
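This new translation unit pins exactly one template instance, which is what keeps `make -j` effective: the _api.cpp only sees a declaration, the heavy definition lives in fused_moegemm_api_internal.hpp, and each instance .cpp includes the definition and explicitly instantiates one configuration, so instances compile in parallel. The same pattern in miniature (run_instance/MyCfg are hypothetical names):

    // api.hpp - declaration only, cheap to include everywhere
    template <typename Cfg>
    float run_instance(const ck_tile::stream_config& s, fused_moegemm_args a);

    // instance_bf16.cpp - pull in the definition, then pin one Cfg
    // #include "api_internal.hpp"
    template float run_instance<MyCfg>(const ck_tile::stream_config& s, fused_moegemm_args a);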
@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>()
 template <typename T>
 auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
 {
-    static_assert(t.get_lengths().size() == 3);
+    assert(t.get_lengths().size() == 3);
     int b_ = t.get_lengths()[0];
     int n_ = t.get_lengths()[1];
     int k_ = t.get_lengths()[2];
@@ -156,7 +156,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ADataType = typename TypeConfig::ADataType;
     using GDataType = typename TypeConfig::GDataType;
     using DDataType = typename TypeConfig::DDataType;
-    using AccDataType = typename TypeConfig::AccDataType;
+    // using AccDataType = typename TypeConfig::AccDataType;
     using ODataType = typename TypeConfig::ODataType;
     using AScaleDataType = typename TypeConfig::AScaleDataType;
     using GScaleDataType = typename TypeConfig::GScaleDataType;
@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host verify
     ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
-    ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size});
-    ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size});
+    ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size, hidden_size});
+    ck_tile::HostTensor<DDataType> d_host({experts, intermediate_size, hidden_size});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
     ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // do moe sorting
     if(balance)
     {
-        int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
+        int e_cnt = 0;
+        for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
         {
             topk_ids_host.mData[i] = e_cnt;
             e_cnt++;
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
     else
     {
-        topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913);
+        topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
     }
     ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         base_str += "=" + prec_o;
         if(fused_quant != 0)
         {
-            base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")";
+            base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
         }
         return base_str;
     }();
@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
     fused_moegemm_args args{a_buf.GetDeviceBuffer(),
                             fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
-                            g_buf.GetDeviceBuffer(),
-                            d_buf.GetDeviceBuffer(),
-                            fused_quant != 0
-                                ? sg_buf.GetDeviceBuffer(),
-                            fused_quant != 0
-                                ? sd_buf.GetDeviceBuffer(),
-                            fused_quant == 1
-                                ? sy_buf.GetDeviceBuffer(),
+                            g_perm_buf.GetDeviceBuffer(),
+                            d_perm_buf.GetDeviceBuffer(),
+                            fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
+                            fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
+                            fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
                             o_buf.GetDeviceBuffer(),
                             sorted_token_ids_buf.GetDeviceBuffer(),
                             sorted_weight_buf.GetDeviceBuffer(),
@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
                             num_sorted_tiles_buf.GetDeviceBuffer(),
                             hidden_size,
                             intermediate_size,
-                            num_tokens,
+                            tokens,
                             experts,
-                            stride };
+                            topk,
+                            stride};
     float ave_time = fused_moegemm(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
@@ -473,50 +472,24 @@ int main(int argc, char* argv[])
         return -1;
     std::string prec_i = arg_parser.get_str("prec_i");
+    std::string prec_w = arg_parser.get_str("prec_w");
     std::string prec_o = arg_parser.get_str("prec_o");
-    std::string prec_sx = arg_parser.get_str("prec_sx");
-    std::string prec_sy = arg_parser.get_str("prec_sy");
+    std::string prec_st = arg_parser.get_str("prec_st");
+    std::string prec_sw = arg_parser.get_str("prec_sw");
+    std::string prec_sq = arg_parser.get_str("prec_sq");
+    std::string prec_kw = arg_parser.get_str("prec_kw");
-    if(prec_o == "auto")
-    {
-        prec_o = prec_i;
-    }
-    if(prec_sx == "auto")
-    {
-        prec_sx = "fp32";
-    }
-    if(prec_sy == "auto")
-    {
-        prec_sy = "fp32";
-    }
-    int save_mv = arg_parser.get_int("save_mv");
+    prec_st = (prec_st == "auto") ? "fp32" : prec_st;
+    prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
+    prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
+    prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
     // no dynamic quant case
-    if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
-    {
-        return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
-    {
-        return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
-    {
-        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
-    {
-        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
-    }
-    // dynamic quant case, only in inference
-    else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
-    {
-        return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
-    }
-    else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
+    if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
     {
-        return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
+        return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
+                   arg_parser)
+                   ? 0
+                   : -2;
     }
     return -3;
...
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
 add_subdirectory(10_rmsnorm2d)
 add_subdirectory(11_add_rmsnorm2d_rdquant)
 add_subdirectory(12_smoothquant)
+add_subdirectory(15_fused_moe)
@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global,
     CK_TILE_DEVICE void
     atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
     {
-        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
+        // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
         // X contains multiple T
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global,
         static_assert(get_address_space() == address_space_enum::global, "only support global mem");
-#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
-        bool constexpr use_amd_buffer_addressing =
-            std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
-            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
-            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
-            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
-#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
-        bool constexpr use_amd_buffer_addressing =
-            std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
-#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
-        bool constexpr use_amd_buffer_addressing =
-            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
-            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
-            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
-#else
-        bool constexpr use_amd_buffer_addressing = false;
-#endif
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
         amd_buffer_atomic_add_raw<remove_cvref_t<T>,
...
@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
     return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }
+template <typename DistributedTensor_,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
+                              const tile_window_linear<BottomTensorView_,
+                                                       WindowLengths_,
+                                                       TileDistribution_,
+                                                       LinearBottomDims_>& tile_window,
+                              number<i_access>                     = {},
+                              bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
+}
 /**
  * @brief Loads a tile of data using inline assembly.
  *
...
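A sketch of the intended call pattern for the new overload, with the destination tensor created up front and reused across loads (win is assumed to be a tile_window_linear; get_tile_distribution() is assumed to be available on the window, as it is on the static-distribution window):

    auto dst = make_static_distributed_tensor<DataType>(win.get_tile_distribution());
    load_tile(dst, win);              // issue every access into the existing tensor
    load_tile(dst, win, number<0>{}); // or re-issue just a single access index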
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
     return unpacks;
 }
+namespace detail {
+// check if two static_distributed_tensors have the same data type and element count
+// and differ only in their distribution
+template <typename X, typename Y>
+struct is_similiar_distributed_tensor
+{
+    static constexpr bool value = false;
+};
+template <typename TypeX, typename DistX, typename TypeY, typename DistY>
+struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
+                                      static_distributed_tensor<TypeY, DistY>>
+{
+    using Tx = static_distributed_tensor<TypeX, DistX>;
+    using Ty = static_distributed_tensor<TypeY, DistY>;
+    static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
+                                  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
+};
+template <typename X, typename Y>
+inline constexpr bool is_similiar_distributed_tensor_v =
+    is_similiar_distributed_tensor<X, Y>::value;
+} // namespace detail
 } // namespace ck_tile
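A sketch of how the new trait can guard code that treats one distributed tensor's thread buffer as another's (DistX/DistY stand for two distribution types over the same data):

    using Tx = static_distributed_tensor<float, DistX>;
    using Ty = static_distributed_tensor<float, DistY>;
    static_assert(detail::is_similiar_distributed_tensor_v<Tx, Ty>,
                  "same element type and thread-buffer size; only the distribution differs");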
@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution
                 0,
                 vec_value,
                 bool_constant<oob_conditional_check>{},
-                bool_constant<pre_nop>);
+                bool_constant<pre_nop>{});
             // move thread coordinate
             if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...
@@ -509,6 +509,64 @@ struct tile_window_linear
         return dst_tensor;
     }
+    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
+    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
+                             number<i_access>                     = {},
+                             bool_constant<oob_conditional_check> = {}) const
+    {
+        using vector_t = typename traits::vector_t;
+        using SFC_Ys   = typename traits::SFC_Ys;
+        constexpr auto tile_dstr = TileDstr{};
+        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess          = number<i_access_>{};
+            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
+            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
+            auto bottom_tensor_flag         = cached_flags_[IAccess];
+            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
+            // read from bottom tensor
+            const vector_t vec_value =
+                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
+                    bottom_tensor_thread_coord,
+                    linear_offset,
+                    bottom_tensor_flag,
+                    bool_constant<oob_conditional_check>{});
+#if 1
+            // data index [y0, y1, ...]
+            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
+            // write into distributed tensor
+            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_tuple(
+                    [&](auto jj) {
+                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
+                    },
+                    number<NDimY>{});
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                dst_tensor.get_thread_buffer().template at<d>() =
+                    vec_value.template get_as<DataType>()[j];
+            });
+#else
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+            static_assert(d % traits::ScalarPerVector == 0);
+            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
+                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
+#endif
+        };
+        WINDOW_DISPATCH_ISSUE();
+        return dst_tensor;
+    }
     template <typename DstTile,
               index_t i_access           = -1,
               bool oob_conditional_check = true,
...
@@ -84,6 +84,7 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          typename DataType_,
           index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
...
@@ -37,7 +37,7 @@ struct DeviceMem
             mpDeviceBuf = nullptr;
         }
     }
-    template <T>
+    template <typename T>
     DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
     {
         if(mMemSize != 0)
@@ -109,18 +109,23 @@ struct DeviceMem
     // construct a host tensor with type T
     template <typename T>
-    HostTensor<T> ToHost(std::size_t cpySize = mMemSize)
+    HostTensor<T> ToHost(std::size_t cpySize)
     {
         // TODO: host tensor could be slightly larger than the device tensor
         // we just copy all data from GPU buffer
-        std::size_t host_elements =
-            (cpySize + sizeof(T) - 1) / sizeof(T) HostTensor<T> h_({host_elements});
+        std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
+        HostTensor<T> h_({host_elements});
         if(mpDeviceBuf)
         {
             HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
         }
         return h_;
     }
+    template <typename T>
+    HostTensor<T> ToHost()
+    {
+        return ToHost<T>(mMemSize);
+    }
     void SetZero() const
     {
...
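The default argument had to go because a default argument cannot read the non-static member mMemSize; the zero-argument overload restores the old call syntax. Usage, as a sketch (namespace qualification on DeviceMem assumed):

    ck_tile::HostTensor<float> h({1024});
    ck_tile::DeviceMem buf(h);          // sized from the host tensor
    auto whole = buf.ToHost<float>();   // copy back mMemSize bytes
    auto head  = buf.ToHost<float>(64); // copy back only the first 64 bytes (16 floats)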
@@ -13,7 +13,7 @@
 // [indexing implementation-1]
 // using M_a as constexpr block_size to partition all tokens into different slices
 // each slice map to one expert, and one expert can have multiple slices
-// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
+// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
 // before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
 //                             tok-0      tok-1      tok-2      tok-3      tok-4
 // topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
@@ -22,7 +22,7 @@
 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
 //
-// max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
 // * this could be larger than actual, since actual tokens are on GPU
 //
 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs
     index_t intermediate_size; // n (TP slice this)
     index_t num_tokens;        // input number of tokens for current iteration
     index_t num_experts;       // number of groups
-    // index_t top_k;          // need this?
+    index_t topk;              // need this?
     index_t stride_token;      // for input/output, stride for each row, should >= hidden_size
 };
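The topk field added here is exactly what feeds the padding bound described in the comment block above: with the example numbers (topk = 3, input_tokens = 5, num_experts = 6, M_a = 4) the bound is 3 * 5 + 6 * (4 - 1) = 33 slots, a safe over-estimate of the 28 slots the example layout actually fills. As a one-line helper, sketched:

    // upper bound on the sorted, per-expert block-padded token count;
    // the true count is only known after sorting on the GPU
    constexpr ck_tile::index_t max_num_tokens_padded(
        ck_tile::index_t topk, ck_tile::index_t num_tokens,
        ck_tile::index_t num_experts, ck_tile::index_t block_m)
    {
        return topk * num_tokens + num_experts * (block_m - 1);
    }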
@@ -114,11 +114,11 @@ struct FusedMoeGemmKernel
     using Partitioner = remove_cvref_t<Partitioner_>;
     using Pipeline    = remove_cvref_t<Pipeline_>;
     using Epilogue    = remove_cvref_t<Epilogue_>; // TODO: not used
-    static constexpr index_t kBlockSize = Pipeline::kBlockSize;
     // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
     // static_assert(kBlockPerCu > 0);
     using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
+    static constexpr index_t BlockSize_ = BlockShape::BlockSize;
     using ADataType = typename Pipeline::Problem::ADataType;
     using GDataType = typename Pipeline::Problem::GDataType;
@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel
     {
         // sync with generate.py
         // clang-format off
+        return "";
         // clang-format on
     }
@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
         index_t intermediate_size; // n (TP slice this)
         index_t num_tokens;        // input number of tokens for current iteration
         index_t num_experts;       // number of groups
-        // index_t top_k;          // need this?
+        index_t topk;              // need this?
         index_t stride_token;      // for input/output, stride for each row, should >= hidden_size
     };
@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel
         return bit_cast<Kargs>(hargs);
     }
-    CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
+    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
     {
-        return TilePartitioner::GridSize(num_cu, blocks_per_cu);
+        constexpr index_t block_m = BlockShape::Block_M0;
+        int max_num_tokens_padded =
+            hargs.topk * hargs.num_tokens + hargs.num_experts * (block_m - 1);
+        return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
     }
-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
+        // return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
+        return Pipeline::GetSmemSize();
     }
     CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel
             *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
         constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
-        index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
-        index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
-        index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1;       // should be same as kr_0
-        index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0
+        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
+        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
+        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
+        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
         index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
         index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel
         __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
         // note this is in unit of tile, need multiple tile size to get the index
-        const auto [sorted_tile_id, hidden_tile_id] =
-            TilePartitioner{}(num_sorted_tiles, kargs.intermediate_size);
+        const auto [sorted_tile_id, intermediate_tile_id] =
+            Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
         if(sorted_tile_id >= num_sorted_tiles)
             return;
@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel
             reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
         // index along intermediate_size
-        index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0);
-        index_t hidden_idx_nr =
-            __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0);
+        // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
+        // BlockShape::Block_N0);
+        index_t interm_idx_nr =
+            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
         const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
         const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel
             const auto a_window_ = make_tile_window(
                 a_gather_view_,
-                make_tuple(number<BlockShape::Block_M0>{}, number<Pipeline::Block_K0>{}),
+                make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
                 {0, 0});
             return a_window_;
         }();
@@ -274,60 +279,58 @@ struct FusedMoeGemmKernel
         const auto g_window = [&]() {
             const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                      static_cast<long_index_t>(expert_id) * expert_stride_0 +
-                                     hidden_idx_nr * kr_0 * BlockShape::Block_W0;
+                                     interm_idx_nr * kr_0 * BlockShape::Block_W0;
             const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                 g_ptr,
-                make_tuple(nr_0, kr_0, number<Pipeline::Block_W0>{}),
-                make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
+                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
+                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                 number<Pipeline::kAlignmentG>{},
                 number<1>{});
             const auto g_view_1_ =
                 pad_tensor_view(g_view_,
-                                make_tuple(number<Pipeline::Block_Nr0>{},
-                                           number<Pipeline::Block_Kr0>{},
-                                           number<Pipeline::Block_W0>{}),
+                                make_tuple(number<BlockShape::Block_Nr0>{},
+                                           number<BlockShape::Block_Kr0>{},
+                                           number<BlockShape::Block_W0>{}),
                                 sequence<PadIntermediateSize, PadHiddenSize, 0>{});
             const auto g_window_ = make_tile_window(g_view_1_,
                                                     make_tuple(number<BlockShape::Block_Nr0>{},
-                                                               number<Pipeline::Block_Kr0>{},
-                                                               number<Pipeline::Block_W0>{}),
+                                                               number<BlockShape::Block_Kr0>{},
+                                                               number<BlockShape::Block_W0>{}),
                                                     {0, 0, 0});
             return g_window_;
         }();
         const auto d_window = [&]() {
-            const DDataType* d_ptr = [&]() {
-                reinterpret_cast<const DDataType*>(kargs.d_ptr) +
-                    static_cast<long_index_t>(expert_id) * expert_stride_1 +
-                    hidden_idx_nr* BlockShape::Block_W1;
-                // note hidden_idx_nr is along the gemm-k dim of 2nd gemm
-            }();
+            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
+                                     interm_idx_nr * BlockShape::Block_W1;
+            // note interm_idx_nr is along the gemm-k dim of 2nd gemm
             const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                 d_ptr,
-                make_tuple(nr_1, kr_1, Pipeline::Block_W1),
-                make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
+                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
+                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                 number<Pipeline::kAlignmentD>{},
                 number<1>{});
             const auto d_view_1_ =
                 pad_tensor_view(d_view_,
-                                make_tuple(number<Pipeline::kBlockNr_1>{},
-                                           number<Pipeline::kBlockKr_1>{},
-                                           number<Pipeline::Block_W1>{}),
+                                make_tuple(number<BlockShape::Block_Nr1>{},
+                                           number<BlockShape::Block_Kr1>{},
+                                           number<BlockShape::Block_W1>{}),
                                 sequence<PadHiddenSize, PadIntermediateSize, 0>{});
             const auto d_window_ = make_tile_window(d_view_1_,
-                                                    make_tuple(number<Pipeline::kBlockNr_1>{},
-                                                               number<Pipeline::kBlockKr_1>{},
-                                                               number<Pipeline::Block_W1>{}),
+                                                    make_tuple(number<BlockShape::Block_Nr1>{},
+                                                               number<BlockShape::Block_Kr1>{},
+                                                               number<BlockShape::Block_W1>{}),
                                                     {0, 0, 0});
             return d_window_;
         }();
         auto o_window = [&]() {
-            const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr);
-            const auto o_view_     = make_naive_tensor_view<address_space_enum::global,
+            ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
+            auto o_view_     = make_naive_tensor_view<address_space_enum::global,
                                                        memory_operation_enum::atomic_add>(
                 o_ptr,
                 make_tuple(kargs.num_tokens, kargs.hidden_size),
@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel
                 number<1>{});
             // gather is here
-            const auto o_scatter_view_ = transform_tensor_view(
+            auto o_scatter_view_ = transform_tensor_view(
                 o_view_,
                 make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                            make_pass_through_transform(kargs.hidden_size)),
                 make_tuple(sequence<0>{}, sequence<1>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
-            const auto o_window_ = make_tile_window(
+            auto o_window_ = make_tile_window(
                 o_scatter_view_,
-                make_tuple(number<BlockShape::Block_M1>{}, number<Pipeline::Block_N1>{}),
+                make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
                 {0, 0});
             return o_window_;
         }();
...
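GridSize now derives the launch grid from the host args instead of CU occupancy. Numerically, with the shipped instance (Block_M0 = 32) and, say, topk = 8, num_tokens = 128, num_experts = 32: max padded tokens = 8 * 128 + 32 * 31 = 2016, giving ceil(2016 / 32) = 63 token tiles in grid.y and ceil(intermediate_size / Block_N0) tiles in grid.x. A host-side mirror of that math, as a sketch:

    ck_tile::index_t block_m = 32;                                            // BlockShape::Block_M0
    ck_tile::index_t padded  = 8 * 128 + 32 * (block_m - 1);                  // 2016
    ck_tile::index_t ms      = ck_tile::integer_divide_ceil(padded, block_m); // 63
    // ns = integer_divide_ceil(intermediate_size, Block_N0); grid = dim3(ns, ms, 1)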
@@ -58,14 +58,15 @@ struct FusedMoeGemmShape
     static constexpr index_t NumWarps =
         reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
+    // TODO: we don't support rounding half a warp up to 1 warp here
     static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
     static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
     static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
     static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
-    static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(numner<0>{});
-    static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(numner<1>{});
-    static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(numner<2>{});
+    static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
+    static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
+    static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
     static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
     static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
     static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
@@ -83,12 +84,12 @@ struct FusedMoeGemmShape
     static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
     static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
     static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
-    static constexpr index_t WarpPerBlock_M1 = WarpTile_1::at(numner<0>{});
-    static constexpr index_t WarpPerBlock_N1 = WarpTile_1::at(numner<1>{});
-    static constexpr index_t WarpPerBlock_K1 = WarpTile_1::at(numner<2>{});
-    static constexpr index_t Warp_M1 = WarpPerBlock_1::at(number<0>{});
-    static constexpr index_t Warp_N1 = WarpPerBlock_1::at(number<1>{});
-    static constexpr index_t Warp_K1 = WarpPerBlock_1::at(number<2>{});
+    static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
+    static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
+    static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
+    static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
+    static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
+    static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
     static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
     static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
@@ -119,6 +120,6 @@ struct FusedMoeGemmShape
     static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
     static_assert(Block_W0 == Block_W1);
-    static_assert(Block_Nr0 == Block_Kr1);
+    // static_assert(Block_Nr0 == Block_Kr1);
 };
 } // namespace ck_tile
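The Warp*/WarpPerBlock* swaps fixed above are cheap to lock in with a compile-time probe; a sketch instantiating the shape with the shipped instance's tiles (parameter order follows the usage in fused_moegemm_api_internal.hpp above; gate_only = 1, so BlockTile_1 keeps the full 512-wide intermediate):

    using Shape = ck_tile::FusedMoeGemmShape<ck_tile::sequence<32, 512, 128>, // BlockTile_0
                                             ck_tile::sequence<1, 4, 1>,      // WarpPerBlock_0
                                             ck_tile::sequence<32, 32, 16>,   // WarpTile_0
                                             ck_tile::sequence<32, 128, 512>, // BlockTile_1
                                             ck_tile::sequence<1, 4, 1>,      // WarpPerBlock_1
                                             ck_tile::sequence<32, 32, 16>>;  // WarpTile_1
    static_assert(Shape::WarpPerBlock_M1 == 1 && Shape::Warp_M1 == 32, "fixed accessors line up");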
@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear
     // FusedMoeGemmShape
     using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
-    static constexpr const char* name = "eh"; // expert x hidden
+    static constexpr const char* name = "lin";
     CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
-                                   ck_tile::index_t /*hidden_size*/))
+                                   ck_tile::index_t /*intermediate_size*/)
     {
         index_t i_n = blockIdx.x;
         index_t i_m = blockIdx.y;
@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear
         return ck_tile::make_tuple(i_m, i_n);
     }
-    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
+    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
     {
         // TODO: this may need tuning
         index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
-        index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0);
+        index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
         return dim3(ns, ms, 1);
     }
 };
...
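The renamed partitioner ("lin") maps blockIdx straight to tile coordinates: grid.x walks intermediate tiles, grid.y walks padded token tiles, and blocks whose sorted_tile_id lands beyond num_sorted_tiles exit immediately in the kernel. Its host-side counterpart, as a sketch (Shape as in the previous sketch; the numbers are illustrative):

    auto grid = ck_tile::FusedMoeGemmTilePartitioner_Linear<Shape>::GridSize(
        /*max_tokens=*/2016, /*intermediate_size=*/4096);
    // grid == dim3(ceil(4096 / Shape::Block_N0), ceil(2016 / Shape::Block_M0), 1)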
...@@ -5,10 +5,7 @@ ...@@ -5,10 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -40,19 +37,19 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -40,19 +37,19 @@ struct FusedMoeGemmPipeline_Flatmm
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType; using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType; using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType; using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType; using YDataType = typename Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits; using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly; static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize; static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>(); static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>(); static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::GetAlignment_D<Problem>(); static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::GetAlignment_O<Problem>(); static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t kBlockPerCu = []() { static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1) if constexpr(Problem::kBlockPerCu != -1)
...@@ -66,12 +63,15 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -66,12 +63,15 @@ struct FusedMoeGemmPipeline_Flatmm
static constexpr const char* name = "fused_moe_flatmm"; static constexpr const char* name = "fused_moe_flatmm";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// TODO: there are multiple buffers // TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{ {
return Policy<Problem>::GetSmemSize_A(); return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -95,21 +95,22 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -95,21 +95,22 @@ struct FusedMoeGemmPipeline_Flatmm
const GWindow& g_window_, const GWindow& g_window_,
const DWindow& d_window_, const DWindow& d_window_,
OWindow& o_window_, OWindow& o_window_,
TopkWeightDataType topk_weight, TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t hidden_size, index_t hidden_size,
index_t intermediate_size) index_t intermediate_size)
{ {
_Pragma("clang diagnostic push") _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") constexpr auto NEG1 = constexpr auto NEG1 = number<-1>{};
number<-1>{} : constexpr auto I0 = number<0>{}; constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{}; constexpr auto I1 = number<1>{};
constexpr auto TRUE = bool_constant<true>{}; constexpr auto TRUE = bool_constant<true>{};
constexpr auto FALSE = bool_constant<false>{}; constexpr auto FALSE = bool_constant<false>{};
CK_TILE_LDS_ADDR void* smem_0 = smem; CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR void* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR void*>( CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) + Pipeline::GetSmemSize_A()); reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
Policy::template GetSmemSize_A<Problem>());
auto g_view = g_window_.get_bottom_tensor_view(); auto g_view = g_window_.get_bottom_tensor_view();
...@@ -120,8 +121,8 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -120,8 +121,8 @@ struct FusedMoeGemmPipeline_Flatmm
} }
else else
{ {
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0; index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0; index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
const GDataType* g_ptr = const GDataType* g_ptr =
g_window_.get_bottom_tensor_view().get_buffer_view().p_data_; g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
...@@ -153,23 +154,28 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -153,23 +154,28 @@ struct FusedMoeGemmPipeline_Flatmm
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>()); o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
using g_thread_type = decltype(load_tile(g_win)); using g_thread_type = decltype(load_tile(g_win));
using u_thread_type = decltype(load_tile(u_win));
using d_thread_type = decltype(load_tile(d_win)); using d_thread_type = decltype(load_tile(d_win));
// using WarpGemm0 = Policy::template GetWarpGemm0<Problem>();
// using WarpGemm1 = Policy::template GetWarpGemm1<Problem>();
// auto warp_gemm_0 = WarpGemm0{};
// auto warp_gemm_1 = WarpGemm1{};
// issues_warps_lanes // issues_warps_lanes
auto a_sst_win0 = auto a_sst_win0 =
make_tile_window_linear(make_tensor_view<address_space_enum::lds>( make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()), smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(), Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0}); {0, 0, 0});
auto a_sst_win1 = auto a_sst_win1 =
make_tile_window_linear(make_tensor_view<address_space_enum::lds>( make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()), smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(), Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0}); {0, 0, 0});
// m*k // m*k
auto a_sld_win0 = [&]() { auto a_sld_win0 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
constexpr auto a_outer_dstr_enc = tile_distribution_encoding< constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>, tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
...@@ -185,11 +191,12 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -185,11 +191,12 @@ struct FusedMoeGemmPipeline_Flatmm
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()), smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(), Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0}, {0, 0},
a_block_dstr_encode); make_static_tile_distribution(a_block_dstr_encode));
}(); }();
// m*k // m*k
auto a_sld_win1 = [&]() { auto a_sld_win1 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
constexpr auto a_outer_dstr_enc = tile_distribution_encoding< constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>, tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
...@@ -205,32 +212,30 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -205,32 +212,30 @@ struct FusedMoeGemmPipeline_Flatmm
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()), smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(), Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0}, {0, 0},
a_block_dstr_encode); make_static_tile_distribution(a_block_dstr_encode));
}(); }();
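Note on the fix in the two hunks above: the window factory consumes a materialized tile distribution, while the old code handed it the raw encoding — one of the compile errors this commit clears by wrapping the encoding in make_static_tile_distribution. A minimal toy sketch of the distinction, with stand-in types invented for the example (these are not ck_tile's real signatures):

    // Toy model only: 'Encoding' is a compile-time layout description,
    // 'Distribution' is what the window factory actually accepts.
    struct Encoding {};
    struct Distribution {};

    constexpr Distribution make_static_tile_distribution(Encoding) { return {}; }

    void make_window(Distribution) {} // accepts only a distribution

    int main()
    {
        // make_window(Encoding{});                             // ill-formed: the old code's mistake
        make_window(make_static_tile_distribution(Encoding{})); // OK: the pattern applied above
        return 0;
    }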
auto bridge_sst_win = [&]() { auto bridge_sst_win = [&]() {
return make_tile_window_linear( return make_tile_window(
make_tensor_view<address_space_enum::lds>( make_tensor_view<address_space_enum::lds>(
smem, Policy::template MakeBridgeLdsStoreDesc<Problem>()), reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(), Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0}); {0, 0});
}; }();
auto bridge_sld_win = [&]() { auto bridge_sld_win = [&]() {
return make_tile_window_linear( return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>( make_tensor_view<address_space_enum::lds>(
smem, Policy::template MakeBridgeLdsLoadDesc<Problem>()), reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsLoadDesc<Problem>()),
Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(), Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
{0, 0}, {0, 0},
Policy::tepmlate MakeYTileDistribution<Problem>()); Policy::template MakeYTileDistribution<Problem>());
}; }();
// also OK with C array, 2 register buffers // also OK with C array, 2 register buffers
statically_indexed_array<g_thread_type, 2> gs; statically_indexed_array<g_thread_type, 2> gs;
using WarpGemm0 = Policy::GetWarpGemm0<Problem>();
using WarpGemm1 = Policy::GetWarpGemm1<Problem>();
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
constexpr auto issues_a = number<a_win.get_num_of_access()>{}; constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{}; constexpr auto issues_g = number<g_win.get_num_of_access()>{};
...@@ -242,8 +247,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -242,8 +247,10 @@ struct FusedMoeGemmPipeline_Flatmm
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1>{}; number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1>{};
constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{}; constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const index_t num_blocks_k0 = (hidden_size + Problem::Block_K0 - 1) / Problem::Block_K0; const index_t num_blocks_k0 =
const index_t num_blocks_n1 = (hidden_size + Problem::Block_N1 - 1) / Problem::Block_N1; (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
const index_t num_blocks_n1 =
(hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
using a_thread_type = decltype(load_tile(a_sld_win0)); using a_thread_type = decltype(load_tile(a_sld_win0));
statically_indexed_array<a_thread_type, 2> as; statically_indexed_array<a_thread_type, 2> as;
...@@ -253,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -253,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
{ {
async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
}; };
auto move_a = [&]() { // auto move_a = [&]() {
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}}); // move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
}; // };
auto sld_a = [&](auto& a_, auto& win_, auto i_access) { auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access); load_tile_raw(a_, win_, i_access);
}; };
...@@ -277,40 +284,41 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -277,40 +284,41 @@ struct FusedMoeGemmPipeline_Flatmm
} }
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
}; };
auto move_g = // auto move_g =
[&]() { // [&]() {
move_tile_window(g_win, // move_tile_window(g_win,
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}}); // {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
} // };
statically_indexed_array<d_thread_type, 2> ds; statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>( auto gld_d = [&]<typename PreNop = bool_constant<false>>(
auto& d_, auto i_access, PreNop = {}) auto& d_, auto i_access, PreNop = {})
{ {
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
}; };
auto move_d = [&]() { // auto move_d = [&]() {
// d move along gemm-n // // d move along gemm-n
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}}); // move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
}; // };
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>( auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {}) auto& o_, auto i_access, PreNop = {})
{ {
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
} };
auto acc_0 = MakeCBlockTile_Gemm0<Problem>(); auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
auto acc_1s = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{}); auto acc_1s = generate_tuple(
[&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
// clang-format off // clang-format off
auto gemm_0 = [&]<typename PostNop = bool_constant<false>> auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) { (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::GetWarpGemm0<Problem>(); auto warp_gemm = Policy::template GetWarpGemm0<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>; using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
constexpr auto repeat_m = BlockShape::Repeat_M0; constexpr auto repeat_m = BlockShape::Repeat_M0;
constexpr auto repeat_n = BlockShape::Repeat_N0; // constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0; constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k // loop order n->m->k
constexpr auto i_k = i_access % repeat_k; constexpr auto i_k = i_access % repeat_k;
...@@ -320,11 +328,18 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -320,11 +328,18 @@ struct FusedMoeGemmPipeline_Flatmm
using AWarpTensor = typename WarpGemm::AWarpTensor; using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor; using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{}; constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{}; constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a; AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data( w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros), merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
...@@ -352,11 +367,11 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -352,11 +367,11 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off // clang-format off
auto gemm_1 = [&]<typename PostNop = bool_constant<false>> auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) { (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::GetWarpGemm1<Problem>(); auto warp_gemm = Policy::template GetWarpGemm1<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>; using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
constexpr auto repeat_m = BlockShape::Repeat_M1; constexpr auto repeat_m = BlockShape::Repeat_M1;
constexpr auto repeat_n = BlockShape::Repeat_N1; // constexpr auto repeat_n = BlockShape::Repeat_N1;
constexpr auto repeat_k = BlockShape::Repeat_K1; constexpr auto repeat_k = BlockShape::Repeat_K1;
// loop order n->m->k // loop order n->m->k
constexpr auto i_k = i_access % repeat_k; constexpr auto i_k = i_access % repeat_k;
...@@ -366,11 +381,18 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -366,11 +381,18 @@ struct FusedMoeGemmPipeline_Flatmm
using AWarpTensor = typename WarpGemm::AWarpTensor; using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor; using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{}; constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{}; constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a; AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data( w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros), merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
...@@ -419,7 +441,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -419,7 +441,7 @@ struct FusedMoeGemmPipeline_Flatmm
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{}); gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_swin_1, number<i_issue / mfma_per_sld_a>{}); sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
}); });
// compute buffer 1 // compute buffer 1
...@@ -432,14 +454,14 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -432,14 +454,14 @@ struct FusedMoeGemmPipeline_Flatmm
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{}); gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I0], a_swin_0, number<i_issue / mfma_per_sld_a>{}); sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
}); });
}; };
auto pipeline_gemm0_tail = [&]() { auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0; constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0; constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
constexpr index_t mfma_per_gld_a = total_loops / issues_a; // constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a; constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// compute buffer 0 // compute buffer 0
...@@ -452,7 +474,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -452,7 +474,7 @@ struct FusedMoeGemmPipeline_Flatmm
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{}); // gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_swin_1, number<i_issue / mfma_per_sld_a>{}); sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
}); });
// compute buffer 1 // compute buffer 1
...@@ -461,14 +483,14 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -461,14 +483,14 @@ struct FusedMoeGemmPipeline_Flatmm
}); });
}; };
auto y = Policy::MakeYBlockTile<Problem>(); auto y = Policy::template MakeYBlockTile<Problem>();
auto pipeline_bridge = [&]() { auto pipeline_bridge = [&]() {
// cast to Y data // cast to Y data
auto y_pre = cast_tile<YDataType>(acc_0); auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre); store_tile(bridge_sst_win, y_pre);
clear_tile(acc_1s(I0)); clear_tile(acc_1s(I0));
wave_barrier(); // wave_barrier();
load_tile(y, bridge_sld_win); load_tile(y, bridge_sld_win);
clear_tile(acc_1s(I1)); clear_tile(acc_1s(I1));
}; };
...@@ -481,7 +503,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -481,7 +503,7 @@ struct FusedMoeGemmPipeline_Flatmm
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
...@@ -494,7 +516,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -494,7 +516,7 @@ struct FusedMoeGemmPipeline_Flatmm
// compute buffer 0 // compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
...@@ -511,7 +533,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -511,7 +533,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t mfma_per_gld_d = total_loops / issues_d; constexpr index_t mfma_per_gld_d = total_loops / issues_d;
// compute buffer 0 // compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
}); });
...@@ -522,7 +544,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -522,7 +544,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t mfma_per_atm_o = total_loops / issues_o; constexpr index_t mfma_per_atm_o = total_loops / issues_o;
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
...@@ -542,7 +564,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -542,7 +564,7 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off // clang-format off
gld_a(a_sst_win0, NEG1, TRUE); gld_a(a_sst_win0, NEG1, TRUE);
gld_g(gs[I0], NEG1, TRUE); gld_g(gs[I0], NEG1, TRUE);
sld_a(as[I0], a_swin_0, NEG1); sld_a(as[I0], a_sld_win0, NEG1);
gld_a(a_sst_win1, NEG1); gld_a(a_sst_win1, NEG1);
clear_tile(acc_0); clear_tile(acc_0);
...@@ -561,7 +583,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -561,7 +583,7 @@ struct FusedMoeGemmPipeline_Flatmm
const index_t iters_1 = (num_blocks_n1 - 2) / 2; const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0; index_t i_1 = 0;
pipeline_gemm1_head(); pipeline_gemm1_head();
while(i_0 < iters_0) while(i_1 < iters_1)
{ {
pipeline_gemm1(); pipeline_gemm1();
} }
......
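Stepping back from the hunks: the routine they patch is a two-stage, double-buffered pipeline. The skeleton below is a hedged reconstruction from the visible lambdas and loop bounds (pipeline_gemm0, pipeline_bridge, pipeline_gemm1, iters_0, iters_1); the peeling arithmetic and the loop increments are assumptions inferred from the head/tail helpers, not the full source.

    // Assumed overall schedule of FusedMoeGemmPipeline_Flatmm::operator():
    const index_t iters_0 = (num_blocks_k0 - 2) / 2; // two K-blocks peeled into head/tail
    index_t i_0 = 0;
    pipeline_gemm0_head();              // prologue: prime LDS buffers 0 and 1
    while(i_0++ < iters_0)
        pipeline_gemm0();               // steady state: ping-pong buffer 0 <-> buffer 1
    pipeline_gemm0_tail();              // epilogue: drain without issuing new global loads

    pipeline_bridge();                  // cast acc_0 -> YDataType, bounce through LDS into y

    const index_t iters_1 = (num_blocks_n1 - 2) / 2;
    index_t i_1 = 0;
    pipeline_gemm1_head();
    while(i_1++ < iters_1)
        pipeline_gemm1();               // gemm1: y x d, double-buffered along gemm-n
    pipeline_gemm1_tail();              // final tiles + atomic update of the O window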
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -21,8 +22,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -21,8 +22,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{ {
// using async // using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords(); constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
static constexpr index_t data_bytes = sizeof(typename Problem::ADataType); constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
} }
...@@ -30,8 +31,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -30,8 +31,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{ {
static constexpr index_t copy_bytes = [&]() { return 16; }(); constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType); constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
} }
...@@ -39,8 +40,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -39,8 +40,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D() CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{ {
static constexpr index_t copy_bytes = [&]() { return 16; }(); constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType); constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0); static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes; return copy_bytes / data_bytes;
} }
...@@ -69,7 +70,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -69,7 +70,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{ {
// TODO: this is for 3d layout // TODO: this is for 3d layout
return 16 / sizeof(remove_cvref_t<typename Problem::DataType_>); return 16 / sizeof(remove_cvref_t<DataType_>);
} }
template <typename Problem> template <typename Problem>
...@@ -78,6 +79,14 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -78,6 +79,14 @@ struct FusedMoeGemmPipelineFlatmmPolicy
return GetSmemKPack<typename Problem::ADataType>(); return GetSmemKPack<typename Problem::ADataType>();
} }
// used for bridge LDS shuffle
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
{
// TODO: this should match mfma layout
return 16 / sizeof(typename Problem::YDataType);
}
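A quick worked check of the helper just added: it returns how many YDataType elements fit in one 16-byte LDS vector, so for the element types appearing in this commit:

    static_assert(16 / sizeof(ck_tile::bf16_t) == 8);  // bf16 Y: 8 elements per vector
    static_assert(16 / sizeof(ck_tile::int8_t) == 16); // int8 Y: 16 elements per vector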
#if 0 #if 0
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape() CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape()
...@@ -222,28 +231,6 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -222,28 +231,6 @@ struct FusedMoeGemmPipelineFlatmmPolicy
} }
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
#if 0 #if 0
// Caution: this will require global memory pre-shuffled to follow the mfma layout // Caution: this will require global memory pre-shuffled to follow the mfma layout
template <index_t NPerBlock, template <index_t NPerBlock,
...@@ -356,22 +343,44 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -356,22 +343,44 @@ struct FusedMoeGemmPipelineFlatmmPolicy
} }
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{ {
// A async->LDS // A async->LDS
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0; constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t BlockSize = Problem::BlockShape::BlockSize; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % kVector == 0); static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / kVector; // how many threads load K constexpr index_t LanesPerK = Block_K / KVector; // how many threads load K
if constexpr(LanesPerK >= warpSize) if constexpr(LanesPerK >= warpSize)
{ {
// need multiple waves to load K // need multiple waves to load K
...@@ -391,9 +400,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -391,9 +400,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<wavesPerK>{}, // k0 number<wavesPerK>{}, // k0
number<warpSize>{}, // k1 number<warpSize>{}, // k1
number<KVector>{}), // k2 number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0 make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1 number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0 number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1 number<KVector>{}, // k1
number<1>{}), // k2 number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store) number<KVector>{}, // lds store vector(actually no explicit store)
...@@ -424,9 +433,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -424,9 +433,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2 number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0 number<LanesPerK>{}, // k0
number<KVector>{}), // k1 number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0 make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1 number<Block_K>{}, // m1
number<warpSize * KVector + kPad>{}, // m2 number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0 number<KVector>{}, // k0
number<1>{}), // k1 number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store) number<KVector>{}, // lds store vector(actually no explicit store)
...@@ -457,16 +466,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -457,16 +466,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// 2). returned descriptor is in NxK 2d layout // 2). returned descriptor is in NxK 2d layout
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0; constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t BlockSize = Problem::BlockShape::BlockSize; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % kVector == 0); static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / kVector; // how many threads load K constexpr index_t LanesPerK = Block_K / KVector; // how many threads load K
if constexpr(LanesPerK >= warpSize) if constexpr(LanesPerK >= warpSize)
{ {
// need multiple waves to load K // need multiple waves to load K
...@@ -486,9 +495,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -486,9 +495,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<wavesPerK>{}, // k0 number<wavesPerK>{}, // k0
number<warpSize>{}, // k1 number<warpSize>{}, // k1
number<KVector>{}), // k2 number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0 make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1 number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0 number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1 number<KVector>{}, // k1
number<1>{}), // k2 number<1>{}), // k2
number<KPack>{}, // lds load vector number<KPack>{}, // lds load vector
...@@ -519,9 +528,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -519,9 +528,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2 number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0 number<LanesPerK>{}, // k0
number<KVector>{}), // k1 number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0 make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1 number<Block_K>{}, // m1
number<warpSize * KVector + kPad>{}, // m2 number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0 number<KVector>{}, // k0
number<1>{}), // k1 number<1>{}), // k1
number<KPack>{}, // lds load vector number<KPack>{}, // lds load vector
...@@ -546,13 +555,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -546,13 +555,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0; constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps constexpr index_t KPad = KVector; // pad between warps
constexpr auto desc = = constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}), make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + kPad>{}, number<1>{}), make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KPack>{}, number<KVector>{},
number<1>{}); number<1>{});
return desc; return desc;
} }
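The Block_N + KPad row stride in both bridge descriptors is the usual LDS bank-conflict padding: each row is offset by one extra vector so the same column of successive rows lands in different banks. A self-contained illustration (the concrete values are assumed for the example, not taken from BlockShape):

    #include <cstdio>

    int main()
    {
        constexpr int Block_N = 64;
        constexpr int KVector = 8;           // e.g. 16 bytes / sizeof(bf16)
        constexpr int KPad    = KVector;     // one extra vector of padding per row
        constexpr int stride  = Block_N + KPad;
        // element (m, n) sits at m * stride + n; row starts are 72 apart, not 64,
        // so column 0 of consecutive rows no longer maps to the same bank group
        for(int m = 0; m < 3; ++m)
            std::printf("row %d starts at element %d\n", m, m * stride);
        return 0;
    }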
...@@ -563,13 +572,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -563,13 +572,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0; constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps constexpr index_t KPad = KVector; // pad between warps
constexpr auto desc = = constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}), make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + kPad>{}, number<1>{}), make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KPack>{}, number<KVector>{},
number<1>{}); number<1>{});
return desc; return desc;
} }
...@@ -582,16 +591,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -582,16 +591,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// TODO: this is ugly // TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav; constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
// TODO: ugly // TODO: ugly
if constexpr(std::is_same_v<Problem::ADataType, ck_tile::bf16_t> && if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<Problem::GDataType, ck_tile::bf16_t> && S_::Warp_M0 == 32 && std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_N0 == 32 && S_::Warp_K0 == 16) S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{}; 2>>{};
} }
else if constexpr(std::is_same_v<Problem::ADataType, ck_tile::int8_t> && else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<Problem::GDataType, ck_tile::int8_t> && std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
...@@ -606,16 +615,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -606,16 +615,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vva; constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vva;
// TODO: ugly // TODO: ugly
if constexpr(std::is_same_v<Problem::YDataType, ck_tile::bf16_t> && if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<Problem::DDataType, ck_tile::bf16_t> && S_::Warp_M0 == 32 && std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_N0 == 32 && S_::Warp_K0 == 16) S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{}; 2>>{};
} }
else if constexpr(std::is_same_v<Problem::YDataType, ck_tile::int8_t> && else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<Problem::DDataType, ck_tile::int8_t> && std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
...@@ -625,11 +634,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -625,11 +634,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0() const CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
{ {
using S_ = remove_cvref_t<typename Problem::BlockShape>; using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>; using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
using CDataType = WarpGemm::WarpGemm; using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding = constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
...@@ -648,11 +657,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -648,11 +657,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{ {
using S_ = remove_cvref_t<typename Problem::BlockShape>; using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>; using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = WarpGemm::CDataType; using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding = constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
...@@ -672,11 +681,10 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -672,11 +681,10 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// this is used as A matrix for 2nd gemm // this is used as A matrix for 2nd gemm
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYTileDistribution() const CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{ {
using S_ = remove_cvref_t<typename Problem::BlockShape>; using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>; using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using YDataType = typename Problem::YDataType;
// TODO: all waves are along different N, but same M // TODO: all waves are along different N, but same M
constexpr auto y_outer_dstr_enc = constexpr auto y_outer_dstr_enc =
...@@ -694,54 +702,12 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -694,54 +702,12 @@ struct FusedMoeGemmPipelineFlatmmPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYBlockTile() const CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
{ {
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>(); constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
auto y_block_tensor = make_static_distributed_tensor<CDataType>(y_block_dstr); auto y_block_tensor =
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor; return y_block_tensor;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm1<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1;
constexpr index_t KPerBlock = Problem::BlockShape::kBlockK_1;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
}; };
} // namespace ck_tile } // namespace ck_tile
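Two classes of fix repeat throughout this policy file and are worth calling out: dependent type names need typename (Problem::ADataType alone does not parse inside a template), and the bf16 branch must select a bf16 MFMA implementation rather than the f16 one. A reduced, standalone sketch of the dispatch shape, with stub types invented for the example:

    #include <type_traits>

    struct bf16_t {};                         // stand-ins for ck_tile's types
    template <int M, int N, int K> struct MfmaBf16 {};
    template <int M, int N, int K> struct MfmaInt8 {};

    template <typename Problem, typename S>
    constexpr auto get_warp_gemm_0()
    {
        // 'typename' is mandatory before the dependent name Problem::ADataType;
        // omitting it is exactly the error pattern this commit repairs.
        if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                     S::Warp_M0 == 32 && S::Warp_N0 == 32 && S::Warp_K0 == 16)
            return MfmaBf16<32, 32, 8>{};     // bf16 inputs -> Bf16Bf16F32 MFMA, not f16
        else
            return MfmaInt8<32, 32, 16>{};    // int8 fallback in this sketch
    }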
...@@ -35,9 +35,9 @@ struct WarpGemmImpl ...@@ -35,9 +35,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{ {
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> && static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>); detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>; using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>; using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
...@@ -85,8 +85,8 @@ struct WarpGemmImpl ...@@ -85,8 +85,8 @@ struct WarpGemmImpl
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{ {
using CTensor = CWarpTensor; using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>); detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c; CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
......
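Last, the WarpGemmImpl asserts: comparing a type with itself (is_similiar_distributed_tensor_v<ATensor, ATensor>) is vacuously true, so the old checks could never fire; the fix compares each argument against the warp-level tensor the MFMA expects. A toy version of why the old form was inert, using std::is_same_v as a stand-in for the library's similarity check:

    #include <type_traits>

    template <typename ExpectedA>
    struct WarpGemmSketch
    {
        template <typename ATensor>
        void operator()(const ATensor&) const
        {
            // static_assert(std::is_same_v<ATensor, ATensor>);  // always true: checks nothing
            static_assert(std::is_same_v<ATensor, ExpectedA>,    // rejects mismatched tiles
                          "input tile must match the warp gemm's A layout");
        }
    };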