"...composable_kernel_rocm.git" did not exist on "ba6f79a75e65610871fd5139311817642292085c"
Commit 6cb91035 authored by letaoqin's avatar letaoqin
Browse files

add fp16 to test

parent 4dd77195
......@@ -13,6 +13,22 @@
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::fp16_t;
using GDataType = ck_tile::fp16_t;
using DDataType = ck_tile::fp16_t;
using AccDataType = float;
using ODataType = ck_tile::fp16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
......
......@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
typename Ts_::WarpPerBlock_1,
typename Ts_::WarpTile_1>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
......
......@@ -49,7 +49,7 @@ struct fmoe_ // traits, ugly name, only used for internal
;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_;
......
......@@ -8,7 +8,7 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -252,8 +252,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host);
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
// ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
......@@ -287,7 +287,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl;
......@@ -431,7 +431,7 @@ int main(int argc, char* argv[])
// no dynamic quant case
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
......
......@@ -88,6 +88,32 @@ struct FusedMoeGemmPipeline_General
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
template <typename T>
CK_TILE_HOST_DEVICE static void PrintMem(T& tensor)
{
constexpr auto spans = T::get_distributed_spans();
int counter = 0;
sweep_tile_span(spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx =
get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f"
" \n",
row,
col,
counter,
ck_tile::type_convert<float>(tensor(i_j_idx)));
counter = counter + 1;
}
});
});
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
......@@ -131,56 +157,13 @@ struct FusedMoeGemmPipeline_General
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
{
// check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idk_0 = idxk.impl_.at(0);
printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n",
idm_0,
idk_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
}
PrintMem(a_dram_block);
#endif
auto g_dram_block = load_tile(g_global_to_dram_window);
#if 0
{
constexpr auto g_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(g_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(g_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f"
" \n",
row,
col,
counter,
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
}
});
});
}
PrintMem(g_dram_block);
#endif
clear_tile(s_acc); // initialize C
......@@ -215,32 +198,8 @@ struct FusedMoeGemmPipeline_General
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout(activation, s_acc, s_acc);
#if 1
{
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0;
// a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr auto i_j_idx = make_tuple(idxm, idxn);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in c row is %d , col is %d, counter is %d, value is: "
"%f \n",
row,
col,
counter,
ck_tile::type_convert<float>(s_acc(i_j_idx)));
}
});
});
}
#if 0
PrintMem(s_acc);
#endif
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
......@@ -249,15 +208,30 @@ struct FusedMoeGemmPipeline_General
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
// cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc);
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// //y_pre.get_thread_buffer()(i) = type_convert<YDataType>(s_acc.get_thread_buffer()[i]);
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// printf("soure value: %f to value: %f\n",
// s_acc.get_thread_buffer()[i],
// type_convert<float>(y_pre.get_thread_buffer()[i]));
// }
// });
#if 1
PrintMem(y_pre);
#endif
// save to lds
store_tile(bridge_slds_win, y_pre);
block_sync_lds();
// gemm down
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
using SaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = SaccBlockTileType{};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
// y data
auto bridge_llds_win =
make_tile_window(bridge_lds_view,
......@@ -265,6 +239,7 @@ struct FusedMoeGemmPipeline_General
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
auto y = load_tile(bridge_llds_win);
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
......@@ -278,6 +253,7 @@ struct FusedMoeGemmPipeline_General
index_t iCounter1 = n1_loops - 1;
while(iCounter1 > 0)
{
clear_tile(o_acc);
block_sync_lds();
gemm_1(o_acc, y, d);
block_sync_lds();
......@@ -292,9 +268,16 @@ struct FusedMoeGemmPipeline_General
}
// tail
{
clear_tile(o_acc);
block_sync_lds();
gemm_1(o_acc, y, d);
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
}
#if 0
PrintMem(o_acc);
#endif
// store_tile(o_window_, a_dram_block);
}
};
......
......@@ -13,7 +13,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
namespace ck_tile {
......@@ -230,7 +230,28 @@ struct FusedMoeGemmPipelineGeneralPolicy
typename S_::WarpPerBlock_1,
decltype(warp_gemm)>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
sequence<1>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
template <typename Problem>
......@@ -240,12 +261,12 @@ struct FusedMoeGemmPipelineGeneralPolicy
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto d_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1>,
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
sequence<0, 1>>{};
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
......@@ -326,7 +347,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
......@@ -358,7 +387,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
......@@ -383,27 +420,5 @@ struct FusedMoeGemmPipelineGeneralPolicy
2>>{};
}
}
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_N1>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
// constexpr auto a_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<NWarp>,
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
// constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
// remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "A distribution is wrong!");
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
// remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment