Unverified Commit f1e53807 authored by Illia Silin, committed by GitHub

Merge branch 'develop' into ck_host_lib

parents 7450417d d9f1ead3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/fused_moe.hpp"
struct fused_moesorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
};
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
{
};
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moe.hpp"
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
{
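// per-kernel sub-config: reuse the caller's stream, disable individual timing, 0 cold iters
// and 1 repeat; the outer launch below does the overall timing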
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
auto o_data_bytes = [&]() {
if(t.prec_o == "fp32")
return 4;
else if(t.prec_o == "fp16" || t.prec_o == "bf16")
return 2;
else if(t.prec_o == "int8" || t.prec_o == "fp8")
return 1;
return 1;
}();
auto t0 = fused_moesorting_trait{"int32", "fp32"};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
};
auto t1 = fused_moegemm_traits{t.prec_i,
t.prec_w,
t.prec_o,
t.prec_st,
t.prec_sw,
t.prec_sq,
t.prec_kw,
t.block_m,
t.activation,
t.gate_only,
t.fused_quant};
auto a1 = fused_moegemm_args{
a.a_ptr, // const void* a_ptr;
a.a_scale_ptr, // const void* a_scale_ptr;
a.g_ptr, // const void* g_ptr;
a.d_ptr, // const void* d_ptr;
a.g_scale_ptr, // const void* g_scale_ptr;
a.d_scale_ptr, // const void* d_scale_ptr;
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
a.o_ptr, // void* o_ptr;
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
a.hidden_size, // index_t hidden_size;
a.intermediate_size, // index_t intermediate_size;
a.num_tokens, // index_t num_tokens;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.stride_token // index_t stride_token;
};
float r0 = -1;
float r1 = -1;
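// launch moe-sorting followed by moe-gemm back-to-back on the same stream; the outer
// launch_kernel times the combined sequence according to `s`, while each stage itself runs
// once per invocation via s_sub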
float r = ck_tile::launch_kernel(
s,
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
// keep returning a negative value for unsupported cases
if(r0 < 0 || r1 < 0)
return -1;
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API is only declared, not defined, here; otherwise it will block `make -j`
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
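// dispatch: pattern-match the runtime traits (precisions, block_m, gate_only, activation)
// against the fmoe_ trait combinations that have explicit template instantiations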
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not define this template function inside the _api.cpp, otherwise it will block `make -j`
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
constexpr auto get_activation_ = []() {
if constexpr(Ts_::Activation == 0)
{
return ck_tile::element_wise::FastGeluAsm{};
}
else
return ck_tile::element_wise::Silu{};
};
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
f_act_, // TODO: hardcoded
f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
static int printed = 0;
auto kargs = f_kernel::MakeKargs(a);
if(s.log_level_ > 0 && printed == 0)
{
std::cout << ", " << f_kernel::GetName() << std::flush;
printed = 1;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// this is used to pattern-match the internal kernel implementation, not to instantiate kernels
template <typename I,
typename W,
typename O,
typename ST,
typename SW,
typename SQ,
typename KW,
typename BlockTile_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
{
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTile_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ =
BlockTile_::at(ck_tile::number<1>{}); // block intermediate
static constexpr ck_tile::index_t BH_ = BlockTile_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTile_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
using ms_problem = \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
return -1;
}
if(a.moe_buf_bytes % 16)
{
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
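// pick the smem I/O unroll factor from the total number of topk ids; the divisor 64 is
// assumed to be the per-step workgroup/wave granularity of the sorting kernel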
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
switch(smem_io_unroll_num)
{
case(1): {
MOE_SORTING_DISPATCH(1);
}
case(2): {
MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
}
return -1;
}
#include <algorithm>
#include <cstring>
#include <unordered_set>
#include <vector>
#include <set>
#include "ck_tile/host.hpp"
#include "fused_moe.hpp"
// different error thresholds for different dtypes
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
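// the (experts, n, k) weight is viewed as a 6-D tensor and permuted so that each mfma tile's
// k-vectors land contiguously; the tile shapes below depend on dtype and mfma_type (32x32 vs 16x16)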
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
return t;
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision")
.insert("prec_w", "bf16", "weight precision")
.insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
.insert("balance",
"0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init",
"1",
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
"normalized[0, 1]"
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed");
// w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
auto api_str = [&]() {
if(api == 0)
return std::string("fmoe");
else if(api == 1)
return std::string("moeg");
else if(api == 2)
return std::string("moes");
return std::string("");
}();
auto stride_str = [&]() {
if(stride == hidden_size)
return std::string("");
else
return std::string(", st:") + std::to_string(stride);
}();
std::cout
<< "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType;
using DScaleDataType = typename TypeConfig::DScaleDataType;
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
using IndexDataType = typename TypeConfig::IndexDataType;
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sorted
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sorted
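// worst-case size of the sorted buffers: every expert's token slice may be padded up to a
// multiple of block_m by the sorting kernel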
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
{(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
if(init == 0)
{
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
}
else if(init == 1)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
topk_weight_host);
}
else if(init == 2)
{
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
}
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
{
int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
if(e_cnt >= experts)
e_cnt = 0;
}
}
else
{
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
// leave it here for future debugging purposes
#if 0
a_host.loadtxt("../../ater/input_torch.txt");
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
// topk_ids_host.savetxt("topk_ids_2.txt");
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
g_host.loadtxt("../../ater/w1_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
d_host.loadtxt("../../ater/w2_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
#endif
#if 0
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
#endif
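// FLOPs: 2*M*N*K per GEMM, with gemm-0 sized (tokens*topk) x shared_intermediate_size_0 x
// hidden_size and gemm-1 sized (tokens*topk) x shared_intermediate_size_1 x hidden_size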
auto cal_tflops = [&](auto ms) {
double flop_gemm_0 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
double flop_gemm_1 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
};
// TODO: this estimate uses an expert-by-expert view, just for reference
auto cal_tbps = [&](auto ms) {
double token_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
sizeof(GDataType);
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
sizeof(DDataType);
double o_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
// ignore index, they are too small
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
(static_cast<double>(ms) * 1e-3) / 1e12;
};
if(api == 0)
{
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
ck_tile::DeviceMem sorted_token_ids_buf(
sorted_token_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_buf(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem num_sorted_tiles_buf(
num_sorted_tiles_host.get_element_space_size_in_bytes());
fused_moe_traits traits{prec_i,
prec_w,
prec_o,
prec_st,
prec_sw,
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
fused_moe_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
block_m,
hidden_size,
intermediate_size / tp,
tokens,
experts,
topk,
stride};
float ave_time = fused_moe(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
#define CPU_FUSED_MOE(act_type_) \
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
g_host, \
d_host, \
sa_host, \
sg_host, \
sd_host, \
sy_host, \
o_host, \
sorted_token_ids_host, \
sorted_weight_host, \
sorted_expert_ids_host, \
num_sorted_tiles_host, \
topk_ids_host, \
block_m, \
tokens, \
experts, \
hidden_size, \
intermediate_size / tp, \
topk, \
gate_only)
if(do_validation)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
return pass;
}
else if(api == 1)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
// manually clear the output buffer since the kernel accumulates into it atomically
o_buf.SetZero();
//
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
fused_moegemm_traits traits{prec_i,
prec_w,
prec_o,
prec_st,
prec_sw,
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
intermediate_size / tp,
tokens,
experts,
topk,
stride};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
if(do_validation)
{
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
return pass;
}
return false;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
// no dynamic quant case
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
{
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
return -3;
}
add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
# Batched GEMM
This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable `build/bin/tile_example_batched_gemm`
## example
```
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
-stride_a Tensor A stride (default:128)
-stride_b Tensor B stride (default:128)
-stride_c Tensor C stride (default:128)
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmarking the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
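A hypothetical invocation (a sketch only; it assumes the `-name=value` flag syntax used by other ck_tile examples and the default build location):
```
# from the build directory: run with the default problem sizes, fp16, CPU validation
./bin/tile_example_batched_gemm -prec=fp16 -v=1
```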
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template <typename DataType>
struct BatchedGemmTypeConfig;
template <>
struct BatchedGemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
using Types = BatchedGemmTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "128", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride")
.insert("batch_stride_c", "32768", "Batch C stride")
.insert("batch_count", "16", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t batch_stride_A,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_C = batch_stride_C;
args.batch_count = batch_count;
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Batched Gemm"};
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * N * K +
sizeof(CDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
int run_batched_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
// if stride is zero, fall back to a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
stride_C = f_get_default_stride(M, N, stride_C, c_layout);
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1});
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
int run_batched_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
// if(a_layout == "R" && b_layout == "R")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
// }
if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
// TODO: Fixme: with the latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy the
// layouts below do not work
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
# Grouped CShuffle GEMM
This folder contains an example of grouped GEMM using the ck_tile tile-programming implementation. Currently it only supports the basic features of CK Tile GEMM, but it creates placeholders for future support of different GEMM pipelines and GEMM modules. In the near future, we will gradually migrate all GEMM features from the old CK to CK Tile.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
## example
```
args:
-a_layout Tensor A layout (default:R)
-b_layout Tensor B layout (default:R)
-c_layout Tensor C layout (default:R)
-v 0. No validation, 1. Validation on CPU
-warmup number of iterations before benchmarking the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
```
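A hypothetical invocation (a sketch only; it assumes the `-name=value` flag syntax used by other ck_tile examples and the default build location):
```
# from the build directory: row-major A/B/C with CPU validation
./bin/tile_example_grouped_gemm -a_layout=R -b_layout=R -c_layout=R -v=1
```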