Commit cf646183 authored by carlushuang's avatar carlushuang
Browse files

compile OK

parent 70fa98ad
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/fused_moe.hpp"
#include <string> #include <string>
// this is only a convenient structure for creating an example // this is only a convenient structure for creating an example
...@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename ...@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
struct FusedMoeGemmTypeConfig; struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW> template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>; struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{ {
using ADataType = ck_tile::bf16_t; using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t; using GDataType = ck_tile::bf16_t;
...@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ...@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
}; };
template <typename ST, typename SW, typename SQ, typename KW> template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>; struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{ {
using ADataType = ck_tile::int8_t; using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t; using GDataType = ck_tile::int8_t;
...@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ...@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
}; };
// runtime args // runtime args
struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
{ {
}; };
......
...@@ -3,33 +3,25 @@ ...@@ -3,33 +3,25 @@
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "fused_moegemm.hpp" #include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API only declare, not define here, otherwise will block `make -j` // Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename Traits_> template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a); float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s) float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{ {
template <ck_tile::index_t... Is> // clang-format off
using S = ck_tile::sequence<Is...>;
float r = -1; float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && block_m == 32 && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
ck_tile::bf16_t,
ck_tile::bf16_t,
float,
float,
float,
float,
S<32, 512, 128, 128>,
S<4, 1, 1>,
S<32, 32, 16>,
1,
0>;
fused_moegemm_<t_>(s, a); fused_moegemm_<t_>(s, a);
} }
// clang-format on
return r; return r;
} }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp" #include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp" #include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
template <typename Ts_> template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{ {
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>; using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0, using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts::WarpTile_0, typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1, typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts::WarpTile_0>; typename Ts_::WarpTile_0>;
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType, using f_problem =
typename Ts_::GDataType, ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::DDataType, typename Ts_::GDataType,
typename Ts_::AccDataType, typename Ts_::DDataType,
typename Ts_::ODataType, typename Ts_::AccDataType,
typename Ts_::AScaleDataType, typename Ts_::ODataType,
typename Ts_::GScaleDataType, typename Ts_::AScaleDataType,
typename Ts_::DScaleDataType, typename Ts_::GScaleDataType,
typename Ts_::YSmoothScaleDataType, typename Ts_::DScaleDataType,
typename Ts_::TopkWeightDataType, typename Ts_::YSmoothScaleDataType,
typename Ts_::IndexDataType, typename Ts_::TopkWeightDataType,
ck_tile::Gelu, // TODO: hardcoded typename Ts_::IndexDataType,
f_shape, ck_tile::element_wise::Gelu, // TODO: hardcoded
f_traits> f_shape,
f_traits>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>; using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>; using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a); const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize(); constexpr dim3 blocks = f_kernel::BlockSize();
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
...@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal
{ {
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>; using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = remove_cvref_t<typename TypeConfig::ADataType>; using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = remove_cvref_t<typename TypeConfig::GDataType>; using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = remove_cvref_t<typename TypeConfig::DDataType>; using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = remove_cvref_t<typename TypeConfig::AccDataType>; using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = remove_cvref_t<typename TypeConfig::ODataType>; using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = remove_cvref_t<typename TypeConfig::AScaleDataType>; using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = remove_cvref_t<typename TypeConfig::GScaleDataType>; using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = remove_cvref_t<typename TypeConfig::DScaleDataType>; using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>; using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = remove_cvref_t<typename TypeConfig::TopkWeightDataType>; using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = remove_cvref_t<typename TypeConfig::IndexDataType>; using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr index_t BT_ = BlockTIle_::at(number<0>{}); // block token static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr index_t BI_ = BlockTIle_::at(number<1>{}); // block intermediate static constexpr ck_tile::index_t BI_ =
static constexpr index_t BH_ = BlockTIle_::at(number<2>{}); // block hidden BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
static constexpr index_t BD_ = BlockTIle_::at(number<3>{}); // block down static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = remove_cvref_t<WarpTile_>; using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = remove_cvref_t<WarpTile_>; using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_; static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_; static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
...@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>() ...@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>()
template <typename T> template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0) auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{ {
static_assert(t.get_lengths().size() == 3); assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0]; int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1]; int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2]; int k_ = t.get_lengths()[2];
...@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp; ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>; using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType; using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType; using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType; using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType; // using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType; using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType; using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType; using GScaleDataType = typename TypeConfig::GScaleDataType;
...@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size}); ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size, hidden_size});
ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size}); ck_tile::HostTensor<DDataType> d_host({experts, intermediate_size, hidden_size});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens}); ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size}); ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
...@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// do moe sorting // do moe sorting
if(balance) if(balance)
{ {
int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{ {
topk_ids_host.mData[i] = e_cnt; topk_ids_host.mData[i] = e_cnt;
e_cnt++; e_cnt++;
...@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else else
{ {
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913); topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
} }
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
...@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
base_str += "=" + prec_o; base_str += "=" + prec_o;
if(fused_quant != 0) if(fused_quant != 0)
{ {
base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")"; base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
} }
return base_str; return base_str;
}(); }();
...@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_moegemm_args args{a_buf.GetDeviceBuffer(), fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_buf.GetDeviceBuffer(), g_perm_buf.GetDeviceBuffer(),
d_buf.GetDeviceBuffer(), d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
? sg_buf.GetDeviceBuffer(), fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
? sd_buf.GetDeviceBuffer(),
fused_quant == 1
? sy_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(), sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(), sorted_weight_buf.GetDeviceBuffer(),
...@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size, hidden_size,
intermediate_size, intermediate_size,
num_tokens, tokens,
experts, experts,
stride }; topk,
stride};
float ave_time = fused_moegemm( float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
...@@ -473,50 +472,24 @@ int main(int argc, char* argv[]) ...@@ -473,50 +472,24 @@ int main(int argc, char* argv[])
return -1; return -1;
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx"); std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sy = arg_parser.get_str("prec_sy"); std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
if(prec_o == "auto") std::string prec_kw = arg_parser.get_str("prec_kw");
{ prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_o = prec_i; prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
} prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
if(prec_sx == "auto") prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
{
prec_sx = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int save_mv = arg_parser.get_int("save_mv");
// no dynamic quant case // no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32") if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
{ {
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
} }
return -3; return -3;
......
...@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax) ...@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d) add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant) add_subdirectory(12_smoothquant)
add_subdirectory(15_fused_moe)
...@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global, ...@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global,
CK_TILE_DEVICE void CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x) atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{ {
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type; // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global, ...@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global,
static_assert(get_address_space() == address_space_enum::global, "only support global mem"); static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>, amd_buffer_atomic_add_raw<remove_cvref_t<T>,
......
...@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, ...@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}); return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/** /**
* @brief Loads a tile of data using inline assembly. * @brief Loads a tile of data using inline assembly.
* *
......
...@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number ...@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks; return unpacks;
} }
namespace detail {
// check if 2 static_distributed_tensor has same data type and size of element
// but only difference in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
} // namespace ck_tile } // namespace ck_tile
...@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution ...@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution
0, 0,
vec_value, vec_value,
bool_constant<oob_conditional_check>{}, bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>); bool_constant<pre_nop>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......
...@@ -509,6 +509,64 @@ struct tile_window_linear ...@@ -509,6 +509,64 @@ struct tile_window_linear
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile, template <typename DstTile,
index_t i_access = -1, index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
......
...@@ -84,6 +84,7 @@ template <typename BottomTensorView_, ...@@ -84,6 +84,7 @@ template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
typename DataType_,
index_t i_access = -1, index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
......
...@@ -37,7 +37,7 @@ struct DeviceMem ...@@ -37,7 +37,7 @@ struct DeviceMem
mpDeviceBuf = nullptr; mpDeviceBuf = nullptr;
} }
} }
template <T> template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes()) DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{ {
if(mMemSize != 0) if(mMemSize != 0)
...@@ -109,18 +109,23 @@ struct DeviceMem ...@@ -109,18 +109,23 @@ struct DeviceMem
// construct a host tensor with type T // construct a host tensor with type T
template <typename T> template <typename T>
HostTensor<T> ToHost(std::size_t cpySize = mMemSize) HostTensor<T> ToHost(std::size_t cpySize)
{ {
// TODO: host tensor could be slightly larger than the device tensor // TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer // we just copy all data from GPU buffer
std::size_t host_elements = std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
(cpySize + sizeof(T) - 1) / sizeof(T) HostTensor<T> h_({host_elements}); HostTensor<T> h_({host_elements});
if(mpDeviceBuf) if(mpDeviceBuf)
{ {
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
} }
return h_; return h_;
} }
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
void SetZero() const void SetZero() const
{ {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// [indexing implementation-1] // [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices // using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices // each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5 // e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] // before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4 // tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number) // topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs ...@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
// index_t top_k; // need this? index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size index_t stride_token; // for input/output, stride for each row, should >= hidden_size
}; };
...@@ -111,14 +111,14 @@ struct FusedMoeGemmHostArgs ...@@ -111,14 +111,14 @@ struct FusedMoeGemmHostArgs
template <typename Partitioner_, typename Pipeline_, typename Epilogue_> template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel struct FusedMoeGemmKernel
{ {
using Partitioner = remove_cvref_t<Partitioner_>; using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
static constexpr index_t kBlockSize = Pipeline::kBlockSize;
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0); // static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType; using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType; using GDataType = typename Pipeline::Problem::GDataType;
...@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel ...@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
return "";
// clang-format on // clang-format on
} }
...@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel ...@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
// index_t top_k; // need this? index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size index_t stride_token; // for input/output, stride for each row, should >= hidden_size
}; };
...@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel ...@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel
return bit_cast<Kargs>(hargs); return bit_cast<Kargs>(hargs);
} }
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{ {
return TilePartitioner::GridSize(num_cu, blocks_per_cu); constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * (block_m - 1);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize()); // return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
return Pipeline::GetSmemSize();
} }
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
...@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel ...@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0; index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0; index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0 index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0 index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size; index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
...@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel ...@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index // note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, hidden_tile_id] = const auto [sorted_tile_id, intermediate_tile_id] =
TilePartitioner{}(num_sorted_tiles, kargs.intermediate_size); Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles) if(sorted_tile_id >= num_sorted_tiles)
return; return;
...@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel ...@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size // index along intermediate_size
index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0); // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
index_t hidden_idx_nr = // BlockShape::Block_N0);
__builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0); index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col] const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0; const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
...@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel ...@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel
const auto a_window_ = make_tile_window( const auto a_window_ = make_tile_window(
a_gather_view_, a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<Pipeline::Block_K0>{}), make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0}); {0, 0});
return a_window_; return a_window_;
}(); }();
...@@ -274,61 +279,59 @@ struct FusedMoeGemmKernel ...@@ -274,61 +279,59 @@ struct FusedMoeGemmKernel
const auto g_window = [&]() { const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) + const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 + static_cast<long_index_t>(expert_id) * expert_stride_0 +
hidden_idx_nr * kr_0 * BlockShape::Block_W0; interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>( const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr, g_ptr,
make_tuple(nr_0, kr_0, number<Pipeline::Block_W0>{}), make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1), make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{}, number<Pipeline::kAlignmentG>{},
number<1>{}); number<1>{});
const auto g_view_1_ = const auto g_view_1_ =
pad_tensor_view(g_view_, pad_tensor_view(g_view_,
make_tuple(number<Pipeline::Block_Nr0>{}, make_tuple(number<BlockShape::Block_Nr0>{},
number<Pipeline::Block_Kr0>{}, number<BlockShape::Block_Kr0>{},
number<Pipeline::Block_W0>{}), number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{}); sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_, const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{}, make_tuple(number<BlockShape::Block_Nr0>{},
number<Pipeline::Block_Kr0>{}, number<BlockShape::Block_Kr0>{},
number<Pipeline::Block_W0>{}), number<BlockShape::Block_W0>{}),
{0, 0, 0}); {0, 0, 0});
return g_window_; return g_window_;
}(); }();
const auto d_window = [&]() { const auto d_window = [&]() {
const DDataType* d_ptr = [&]() { const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
reinterpret_cast<const DDataType*>(kargs.d_ptr) + static_cast<long_index_t>(expert_id) * expert_stride_1 +
static_cast<long_index_t>(expert_id) * expert_stride_1 + interm_idx_nr * BlockShape::Block_W1;
hidden_idx_nr* BlockShape::Block_W1; // note interm_idx_nr is along the gemm-k dim of 2nd gemm
// note hidden_idx_nr is along the gemm-k dim of 2nd gemm
}();
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>( const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr, d_ptr,
make_tuple(nr_1, kr_1, Pipeline::Block_W1), make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1), make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{}, number<Pipeline::kAlignmentD>{},
number<1>{}); number<1>{});
const auto d_view_1_ = const auto d_view_1_ =
pad_tensor_view(d_view_, pad_tensor_view(d_view_,
make_tuple(number<Pipeline::kBlockNr_1>{}, make_tuple(number<BlockShape::Block_Nr1>{},
number<Pipeline::kBlockKr_1>{}, number<BlockShape::Block_Kr1>{},
number<Pipeline::Block_W1>{}), number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{}); sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_, const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<Pipeline::kBlockNr_1>{}, make_tuple(number<BlockShape::Block_Nr1>{},
number<Pipeline::kBlockKr_1>{}, number<BlockShape::Block_Kr1>{},
number<Pipeline::Block_W1>{}), number<BlockShape::Block_W1>{}),
{0, 0, 0}); {0, 0, 0});
return d_window_; return d_window_;
}(); }();
auto o_window = [&]() { auto o_window = [&]() {
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr); ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
const auto o_view_ = make_naive_tensor_view<address_space_enum::global, auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>( memory_operation_enum::atomic_add>(
o_ptr, o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size), make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1), make_tuple(kargs.stride_token, 1),
...@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel ...@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel
number<1>{}); number<1>{});
// gather is here // gather is here
const auto o_scatter_view_ = transform_tensor_view( auto o_scatter_view_ = transform_tensor_view(
o_view_, o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id), make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)), make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
const auto o_window_ = make_tile_window( auto o_window_ = make_tile_window(
o_scatter_view_, o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<Pipeline::Block_N1>{}), make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0}); {0, 0});
return o_window_; return o_window_;
}(); }();
......
...@@ -58,14 +58,15 @@ struct FusedMoeGemmShape ...@@ -58,14 +58,15 @@ struct FusedMoeGemmShape
static constexpr index_t NumWarps = static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{}); reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{})); static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(numner<0>{}); static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(numner<1>{}); static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(numner<2>{}); static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{}); static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{}); static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
...@@ -83,12 +84,12 @@ struct FusedMoeGemmShape ...@@ -83,12 +84,12 @@ struct FusedMoeGemmShape
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpTile_1::at(numner<0>{}); static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpTile_1::at(numner<1>{}); static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpTile_1::at(numner<2>{}); static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t Warp_M1 = WarpPerBlock_1::at(number<0>{}); static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpPerBlock_1::at(number<1>{}); static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpPerBlock_1::at(number<2>{}); static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1; static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1; static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
...@@ -119,6 +120,6 @@ struct FusedMoeGemmShape ...@@ -119,6 +120,6 @@ struct FusedMoeGemmShape
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
static_assert(Block_W0 == Block_W1); static_assert(Block_W0 == Block_W1);
static_assert(Block_Nr0 == Block_Kr1); // static_assert(Block_Nr0 == Block_Kr1);
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear ...@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear
// FusedMoeGemmShape // FusedMoeGemmShape
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>; using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
static constexpr const char* name = "eh"; // expert x hidden static constexpr const char* name = "lin";
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/, CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t /*hidden_size*/)) ck_tile::index_t /*intermediate_size*/)
{ {
index_t i_n = blockIdx.x; index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y; index_t i_m = blockIdx.y;
...@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear ...@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear
return ck_tile::make_tuple(i_m, i_n); return ck_tile::make_tuple(i_m, i_n);
} }
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size) CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
{ {
// TODO: this may need tuning // TODO: this may need tuning
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0); index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0); index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
return dim3(ns, ms, 1); return dim3(ns, ms, 1);
} }
}; };
......
...@@ -35,9 +35,9 @@ struct WarpGemmImpl ...@@ -35,9 +35,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{ {
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> && static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>); detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>; using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>; using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
...@@ -85,8 +85,8 @@ struct WarpGemmImpl ...@@ -85,8 +85,8 @@ struct WarpGemmImpl
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{ {
using CTensor = CWarpTensor; using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>); detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c; CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment