Commit cf646183 authored by carlushuang's avatar carlushuang
Browse files

compile OK

parent 70fa98ad
......@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <string>
// this is only a convenient structure for creating an example
......@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t;
......@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t;
......@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
};
// runtime args
struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
{
};
......
......@@ -3,33 +3,25 @@
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && block_m == 32 &&
gate_only == 1)
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::bf16_t,
float,
float,
float,
float,
S<32, 512, 128, 128>,
S<4, 1, 1>,
S<32, 32, 16>,
1,
0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
fused_moegemm_<t_>(s, a);
}
// clang-format on
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0,
typename Ts::WarpTile_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts::WarpTile_0>;
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::Gelu, // TODO: hardcoded
f_shape,
f_traits>
using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
typename Ts_::WarpTile_0>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::Gelu, // TODO: hardcoded
f_shape,
f_traits>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
......@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal
{
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr index_t BT_ = BlockTIle_::at(number<0>{}); // block token
static constexpr index_t BI_ = BlockTIle_::at(number<1>{}); // block intermediate
static constexpr index_t BH_ = BlockTIle_::at(number<2>{}); // block hidden
static constexpr index_t BD_ = BlockTIle_::at(number<3>{}); // block down
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = remove_cvref_t<WarpTile_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = remove_cvref_t<WarpTile_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>()
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
static_assert(t.get_lengths().size() == 3);
assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
......@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
// using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType;
......@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size});
ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, intermediate_size, hidden_size});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
......@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// do moe sorting
if(balance)
{
int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
......@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
{
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913);
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
......@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")";
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
......@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_buf.GetDeviceBuffer(),
d_buf.GetDeviceBuffer(),
fused_quant != 0
? sg_buf.GetDeviceBuffer(),
fused_quant != 0
? sd_buf.GetDeviceBuffer(),
fused_quant == 1
? sy_buf.GetDeviceBuffer(),
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
......@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
intermediate_size,
num_tokens,
tokens,
experts,
stride };
topk,
stride};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......@@ -473,50 +472,24 @@ int main(int argc, char* argv[])
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sx == "auto")
{
prec_sx = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int save_mv = arg_parser.get_int("save_mv");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
// no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
return -3;
......
......@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant)
add_subdirectory(15_fused_moe)
......@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global,
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
......@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global,
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
......
......@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
* @brief Loads a tile of data using inline assembly.
*
......
......@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks;
}
namespace detail {
// check if 2 static_distributed_tensor has same data type and size of element
// but only difference in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
} // namespace ck_tile
......@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution
0,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>);
bool_constant<pre_nop>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......
......@@ -509,6 +509,64 @@ struct tile_window_linear
return dst_tensor;
}
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
......
......@@ -84,6 +84,7 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
......
......@@ -37,7 +37,7 @@ struct DeviceMem
mpDeviceBuf = nullptr;
}
}
template <T>
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
if(mMemSize != 0)
......@@ -109,18 +109,23 @@ struct DeviceMem
// construct a host tensor with type T
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize = mMemSize)
HostTensor<T> ToHost(std::size_t cpySize)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std::size_t host_elements =
(cpySize + sizeof(T) - 1) / sizeof(T) HostTensor<T> h_({host_elements});
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
HostTensor<T> h_({host_elements});
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
return h_;
}
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
void SetZero() const
{
......
......@@ -13,7 +13,7 @@
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
......@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
// index_t top_k; // need this?
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
......@@ -111,14 +111,14 @@ struct FusedMoeGemmHostArgs
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
static constexpr index_t kBlockSize = Pipeline::kBlockSize;
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
......@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel
{
// sync with generate.py
// clang-format off
return "";
// clang-format on
}
......@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
// index_t top_k; // need this?
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
......@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return TilePartitioner::GridSize(num_cu, blocks_per_cu);
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * (block_m - 1);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
// return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
return Pipeline::GetSmemSize();
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
......@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
......@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, hidden_tile_id] =
TilePartitioner{}(num_sorted_tiles, kargs.intermediate_size);
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
......@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0);
index_t hidden_idx_nr =
__builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0);
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
......@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<Pipeline::Block_K0>{}),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
......@@ -274,61 +279,59 @@ struct FusedMoeGemmKernel
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
hidden_idx_nr * kr_0 * BlockShape::Block_W0;
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<Pipeline::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<Pipeline::Block_Nr0>{},
number<Pipeline::Block_Kr0>{},
number<Pipeline::Block_W0>{}),
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<Pipeline::Block_Kr0>{},
number<Pipeline::Block_W0>{}),
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = [&]() {
reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
hidden_idx_nr* BlockShape::Block_W1;
// note hidden_idx_nr is along the gemm-k dim of 2nd gemm
}();
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, Pipeline::Block_W1),
make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<Pipeline::kBlockNr_1>{},
number<Pipeline::kBlockKr_1>{},
number<Pipeline::Block_W1>{}),
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<Pipeline::kBlockNr_1>{},
number<Pipeline::kBlockKr_1>{},
number<Pipeline::Block_W1>{}),
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr);
const auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
......@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel
number<1>{});
// gather is here
const auto o_scatter_view_ = transform_tensor_view(
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto o_window_ = make_tile_window(
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<Pipeline::Block_N1>{}),
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
......
......@@ -58,14 +58,15 @@ struct FusedMoeGemmShape
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(numner<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(numner<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(numner<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
......@@ -83,12 +84,12 @@ struct FusedMoeGemmShape
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpTile_1::at(numner<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpTile_1::at(numner<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpTile_1::at(numner<2>{});
static constexpr index_t Warp_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
......@@ -119,6 +120,6 @@ struct FusedMoeGemmShape
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
static_assert(Block_W0 == Block_W1);
static_assert(Block_Nr0 == Block_Kr1);
// static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile
......@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear
// FusedMoeGemmShape
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
static constexpr const char* name = "eh"; // expert x hidden
static constexpr const char* name = "lin";
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t /*hidden_size*/))
ck_tile::index_t /*intermediate_size*/)
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
......@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear
return ck_tile::make_tuple(i_m, i_n);
}
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
{
// TODO: this may need tuning
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0);
index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
return dim3(ns, ms, 1);
}
};
......
......@@ -5,10 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
......@@ -40,19 +37,19 @@ struct FusedMoeGemmPipeline_Flatmm
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::GetAlignment_O<Problem>();
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
......@@ -66,12 +63,15 @@ struct FusedMoeGemmPipeline_Flatmm
static constexpr const char* name = "fused_moe_flatmm";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy<Problem>::GetSmemSize_A();
return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
......@@ -95,21 +95,22 @@ struct FusedMoeGemmPipeline_Flatmm
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType topk_weight,
TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
{
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") constexpr auto NEG1 =
number<-1>{} : constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto TRUE = bool_constant<true>{};
constexpr auto FALSE = bool_constant<false>{};
CK_TILE_LDS_ADDR void* smem_0 = smem;
CK_TILE_LDS_ADDR void* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR void*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) + Pipeline::GetSmemSize_A());
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
constexpr auto NEG1 = number<-1>{};
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto TRUE = bool_constant<true>{};
constexpr auto FALSE = bool_constant<false>{};
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
Policy::template GetSmemSize_A<Problem>());
auto g_view = g_window_.get_bottom_tensor_view();
......@@ -120,8 +121,8 @@ struct FusedMoeGemmPipeline_Flatmm
}
else
{
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
const GDataType* g_ptr =
g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
......@@ -153,23 +154,28 @@ struct FusedMoeGemmPipeline_Flatmm
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
using g_thread_type = decltype(load_tile(g_win));
using u_thread_type = decltype(load_tile(u_win));
using d_thread_type = decltype(load_tile(d_win));
// using WarpGemm0 = Policy::template GetWarpGemm0<Problem>();
// using WarpGemm1 = Policy::template GetWarpGemm1<Problem>();
// auto warp_gemm_0 = WarpGemm0{};
// auto warp_gemm_1 = WarpGemm1{};
// issues_warps_lanes
auto a_sst_win0 =
make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
auto a_sst_win1 =
make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// m*k
auto a_sld_win0 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
......@@ -185,11 +191,12 @@ struct FusedMoeGemmPipeline_Flatmm
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
a_block_dstr_encode);
make_static_tile_distribution(a_block_dstr_encode));
}();
// m*k
auto a_sld_win1 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
......@@ -205,32 +212,30 @@ struct FusedMoeGemmPipeline_Flatmm
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
a_block_dstr_encode);
make_static_tile_distribution(a_block_dstr_encode));
}();
auto bridge_sst_win = [&]() {
return make_tile_window_linear(
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
smem, Policy::template MakeBridgeLdsStoreDesc<Problem>()),
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0});
};
}();
auto bridge_sld_win = [&]() {
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
smem, Policy::template MakeBridgeLdsLoadDesc<Problem>()),
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsLoadDesc<Problem>()),
Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
{0, 0},
Policy::tepmlate MakeYTileDistribution<Problem>());
};
Policy::template MakeYTileDistribution<Problem>());
}();
// also OK with C array, 2 register buffer
statically_indexed_array<g_thread_type, 2> gs;
using WarpGemm0 = Policy::GetWarpGemm0<Problem>();
using WarpGemm1 = Policy::GetWarpGemm1<Problem>();
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{};
......@@ -242,8 +247,10 @@ struct FusedMoeGemmPipeline_Flatmm
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1>{};
constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const index_t num_blocks_k0 = (hidden_size + Problem::Block_K0 - 1) / Problem::Block_K0;
const index_t num_blocks_n1 = (hidden_size + Problem::Block_N1 - 1) / Problem::Block_N1;
const index_t num_blocks_k0 =
(hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
const index_t num_blocks_n1 =
(hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
using a_thread_type = decltype(load_tile(a_sld_win0));
statically_indexed_array<a_thread_type, 2> as;
......@@ -253,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
{
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
};
auto move_a = [&]() {
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
};
// auto move_a = [&]() {
// move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
// };
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access);
};
......@@ -277,40 +284,41 @@ struct FusedMoeGemmPipeline_Flatmm
}
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
};
auto move_g =
[&]() {
move_tile_window(g_win,
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
} statically_indexed_array<d_thread_type, 2>
ds;
// auto move_g =
// [&]() {
// move_tile_window(g_win,
// {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
// };
statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
auto& d_, auto i_access, PreNop = {})
{
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
};
auto move_d = [&]() {
// d move along gemm-n
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
};
// auto move_d = [&]() {
// // d move along gemm-n
// move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
// };
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {})
{
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
}
};
auto acc_0 = MakeCBlockTile_Gemm0<Problem>();
auto acc_1s = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
auto acc_1s = generate_tuple(
[&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
// clang-format off
auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::GetWarpGemm0<Problem>();
auto warp_gemm = Policy::template GetWarpGemm0<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
constexpr auto repeat_m = BlockShape::Repeat_M0;
constexpr auto repeat_n = BlockShape::Repeat_N0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_k = i_access % repeat_k;
......@@ -320,11 +328,18 @@ struct FusedMoeGemmPipeline_Flatmm
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
......@@ -352,11 +367,11 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::GetWarpGemm1<Problem>();
auto warp_gemm = Policy::template GetWarpGemm1<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
constexpr auto repeat_m = BlockShape::Repeat_M1;
constexpr auto repeat_n = BlockShape::Repeat_N1;
// constexpr auto repeat_n = BlockShape::Repeat_N1;
constexpr auto repeat_k = BlockShape::Repeat_K1;
// loop order n->m->k
constexpr auto i_k = i_access % repeat_k;
......@@ -366,11 +381,18 @@ struct FusedMoeGemmPipeline_Flatmm
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
......@@ -419,7 +441,7 @@ struct FusedMoeGemmPipeline_Flatmm
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_swin_1, number<i_issue / mfma_per_sld_a>{});
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
});
// compute buffer 1
......@@ -432,14 +454,14 @@ struct FusedMoeGemmPipeline_Flatmm
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I0], a_swin_0, number<i_issue / mfma_per_sld_a>{});
sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
});
};
auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
constexpr index_t mfma_per_gld_a = total_loops / issues_a;
// constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// compute buffer 0
......@@ -452,7 +474,7 @@ struct FusedMoeGemmPipeline_Flatmm
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_swin_1, number<i_issue / mfma_per_sld_a>{});
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
});
// compute buffer 1
......@@ -461,14 +483,14 @@ struct FusedMoeGemmPipeline_Flatmm
});
};
auto y = Policy::MakeYBlockTile<Problem>();
auto y = Policy::template MakeYBlockTile<Problem>();
auto pipeline_bridge = [&]() {
// cast to Y data
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
clear_tile(acc_1s(I0));
wave_barrier();
// wave_barrier();
load_tile(y, bridge_sld_win);
clear_tile(acc_1s(I1));
};
......@@ -481,7 +503,7 @@ struct FusedMoeGemmPipeline_Flatmm
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I1], y, ds[I1], i_issue);
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
......@@ -494,7 +516,7 @@ struct FusedMoeGemmPipeline_Flatmm
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I0], y, ds[I0], i_issue);
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
......@@ -511,7 +533,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t mfma_per_gld_d = total_loops / issues_d;
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I0], y, ds[I0], i_issue);
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
});
......@@ -522,7 +544,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t mfma_per_atm_o = total_loops / issues_o;
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_1s[I1], y, ds[I1], i_issue);
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
......@@ -542,7 +564,7 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
gld_a(a_sst_win0, NEG1, TRUE);
gld_g(gs[I0], NEG1, TRUE);
sld_a(as[I0], a_swin_0, NEG1);
sld_a(as[I0], a_sld_win0, NEG1);
gld_a(a_sst_win1, NEG1);
clear_tile(acc_0);
......@@ -561,7 +583,7 @@ struct FusedMoeGemmPipeline_Flatmm
const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0;
pipeline_gemm1_head();
while(i_0 < iters_0)
while(i_1 < iters_1)
{
pipeline_gemm1();
}
......
......@@ -4,8 +4,9 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
......@@ -21,8 +22,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
static constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
......@@ -30,8 +31,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
......@@ -39,8 +40,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
......@@ -69,7 +70,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
// TODO: this is for 3d layout
return 16 / sizeof(remove_cvref_t<typename Problem::DataType_>);
return 16 / sizeof(remove_cvref_t<DataType_>);
}
template <typename Problem>
......@@ -78,6 +79,14 @@ struct FusedMoeGemmPipelineFlatmmPolicy
return GetSmemKPack<typename Problem::ADataType>();
}
// used for bridge LDS shuffle
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
{
// TODO: this should match mfma layout
return 16 / sizeof(typename Problem::YDataType);
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape()
......@@ -222,28 +231,6 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
#if 0
// Caution: this will require global memory pre-shuffled to follow the mfma layout
template <index_t NPerBlock,
......@@ -356,22 +343,44 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % kVector == 0);
constexpr index_t LanesPerK = Block_K / kVector; // how many thread loading K
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
......@@ -391,9 +400,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
......@@ -424,9 +433,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
......@@ -455,18 +464,18 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % kVector == 0);
constexpr index_t LanesPerK = Block_K / kVector; // how many thread loading K
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
......@@ -486,9 +495,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
......@@ -519,9 +528,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
......@@ -546,13 +555,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr auto desc = =
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + kPad>{}, number<1>{}),
number<KPack>{},
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
......@@ -563,13 +572,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr auto desc = =
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + kPad>{}, number<1>{}),
number<KPack>{},
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
......@@ -582,16 +591,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
// TODO: ugly
if constexpr(std::is_same_v<Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<Problem::GDataType, ck_tile::bf16_t> && S_::Warp_M0 == 32 &&
S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<Problem::GDataType, ck_tile::int8_t> &&
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
......@@ -606,16 +615,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vva;
// TODO: ugly
if constexpr(std::is_same_v<Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<Problem::DDataType, ck_tile::bf16_t> && S_::Warp_M0 == 32 &&
S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<Problem::DDataType, ck_tile::int8_t> &&
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
......@@ -625,11 +634,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0() const
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
using CDataType = WarpGemm::WarpGemm;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
......@@ -648,11 +657,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = WarpGemm::CDataType;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
......@@ -672,11 +681,10 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYTileDistribution() const
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using YDataType = typename Problem::YDataType;
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc =
......@@ -694,54 +702,12 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYBlockTile() const
CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
{
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
auto y_block_tensor = make_static_distributed_tensor<CDataType>(y_block_dstr);
auto y_block_tensor =
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
using WarpGemm = GetWarpGemm1<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1;
constexpr index_t KPerBlock = Problem::BlockShape::kBlockK_1;
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
};
} // namespace ck_tile
......@@ -35,9 +35,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
......@@ -85,8 +85,8 @@ struct WarpGemmImpl
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment