Commit d0c80b12 authored by shengnxu

fix more issues; current status: inline asm uses more registers than available

parent a759277d
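For context on the register-pressure note in the commit message: every accumulator pinned to a fixed VGPR via a register-asm binding, and every additional "v"-constrained operand handed to the inline-asm block, occupies a vector register for the whole asm region, which is how the kernel ends up requesting more VGPRs than the target provides. A minimal device-side sketch of that pattern, with illustrative names and register numbers rather than the actual kernel's:

    #include <hip/hip_runtime.h>

    // Hypothetical sketch only: two accumulators pinned to fixed VGPRs (the diff
    // below pins up to v190/v191) plus one "+v" operand. Each extra operand
    // (e.g. the v_os_b4..v_os_b7 offsets removed below) pushes the VGPR count up.
    __device__ void vgpr_pressure_sketch(int token_packed, int* out)
    {
        register int acc0 asm("v190") = 0; // pinned accumulator
        register int acc1 asm("v191") = 0; // pinned accumulator
        int t = token_packed;              // plain scalar bound to one "+v" operand
        asm volatile("v_add_u32 %[a0], %[a0], %[t] \n"
                     : [a0] "+v"(acc0), [a1] "+v"(acc1), [t] "+v"(t));
        out[0] = acc0 + acc1 + t;
    }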
......@@ -37,3 +37,4 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
// clang-format on
return r;
}
......@@ -5,11 +5,26 @@
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include "fused_moegemm_api.cpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
template<typename dtype, typename problem>
struct PipelineDispatch;
template<typename problem> struct PipelineDispatch<ck_tile::int8_t, problem> {
using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<problem>;
};
template<typename problem> struct PipelineDispatch<ck_tile::bf16_t, problem> {
using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
};
template<typename problem> struct PipelineDispatch<ck_tile::fp16_t, problem> {
using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
};
// do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
......@@ -38,8 +53,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<
>;
using f_pipeline = typename PipelineDispatch<typename Ts_::ADataType, f_problem>::type;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
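A standalone sketch of the dispatch pattern introduced above: the A data type selects the pipeline type at compile time, so the fp16/bf16 path and the new int8 path share a single fused_moegemm_ template. The placeholder types below stand in for the ck_tile pipelines and are not the real definitions:

    #include <cstdint>
    #include <type_traits>

    struct PipelineGeneric {}; // stand-in for ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>
    struct PipelineInt8 {};    // stand-in for ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<problem>

    template <typename ADataType> struct PipelineDispatchSketch; // primary template left undefined
    template <> struct PipelineDispatchSketch<std::int8_t> { using type = PipelineInt8;    };
    template <> struct PipelineDispatchSketch<float>       { using type = PipelineGeneric; }; // fp16/bf16 stand-in

    // Mirrors: using f_pipeline = typename PipelineDispatch<typename Ts_::ADataType, f_problem>::type;
    static_assert(std::is_same_v<PipelineDispatchSketch<std::int8_t>::type, PipelineInt8>,
                  "int8 A data selects the int8 pipeline");

    int main() { return 0; }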
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 256, 256>, S<1, 4, 1>, S<16, 16, 64>, 1, 1>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
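This instance file is where the comment in the _internal header pays off: the template body lives in fused_moegemm_api_internal.hpp, each per-configuration .cpp emits exactly one explicit instantiation, and the dispatcher only needs declarations, so make -j can build the instances in parallel instead of funneling every kernel through one translation unit. A minimal sketch of that split, with hypothetical names, shown as one snippet with the conceptual file boundaries marked in comments:

    // --- kernel_api_internal.hpp (template definition, included only by instance TUs) ---
    template <typename Cfg>
    float launch_kernel(int n) { return static_cast<float>(n) * Cfg::scale; }

    // --- instance_int8.cpp (one explicit instantiation per TU, built in parallel) ---
    struct CfgInt8 { static constexpr float scale = 0.5f; };
    template float launch_kernel<CfgInt8>(int);

    // --- kernel_api.cpp (dispatcher: a declaration is enough, no template body here) ---
    // extern template float launch_kernel<CfgInt8>(int);
    // float dispatch(int n) { return launch_kernel<CfgInt8>(n); }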
......@@ -204,7 +204,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType;
// using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType;
......@@ -218,12 +218,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
if (fused_quant == 1)
{
// if (fused_quant == 1)
// {
ck_tile::HostTensor<AScaleDataType> sa_host({tokens, topk});
} else{
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
}
// } else{
// ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
// }
ck_tile::HostTensor<GScaleDataType> sg_host({experts, shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({experts, shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({experts, shared_intermediate_size_1}); // smooth-quant
......@@ -425,7 +425,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts,
block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
......@@ -535,7 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
......@@ -555,7 +555,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
gate_only,
1);
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
......@@ -604,6 +605,13 @@ int main(int argc, char* argv[])
? 0
: -2;
}
else if(prec_i == "int8" && prec_w == "int8" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
return -3;
}
......@@ -107,10 +107,9 @@ void reference_fused_moe(
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_weight_idx;
ck_tile::index_t i_weight_idx = i_token >> 24;
if(fquant == 1)
{
i_weight_idx = i_token >> 24;
i_token = i_token & 0xffffff;
}
if (i_token >= tokens)
......
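The hoisted i_weight_idx above relies on the packing convention of the sorted token-id buffer: the top 8 bits of each entry hold the topk-weight index and the low 24 bits hold the token index, exactly the >> 24 / & 0xffffff pair used throughout this commit. A small worked sketch of that pack/unpack arithmetic (the pack_sorted_id helper is illustrative, not part of the library):

    #include <cassert>
    #include <cstdint>

    // weight index in bits [31:24], token index in bits [23:0]
    inline std::uint32_t pack_sorted_id(std::uint32_t i_weight_idx, std::uint32_t i_token)
    {
        return (i_weight_idx << 24) | (i_token & 0xffffff);
    }

    int main()
    {
        std::uint32_t id = pack_sorted_id(/*i_weight_idx=*/3, /*i_token=*/70000);
        assert((id >> 24) == 3u);          // i_weight_idx = i_token >> 24
        assert((id & 0xffffff) == 70000u); // i_token      = i_token & 0xffffff
        return 0;
    }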
......@@ -245,9 +245,9 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// TODO: needs to be paired with tile_window_linear!
// TODO: init_raw() must be called before calling this function!
template <typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
template <typename AToken_id, typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()( index_t row_ids_a_,
operator()( const AToken_id& row_ids_a_,
const AQRes& res_aq,
const DQRes& res_dq,
const GQRes& res_gq,
......@@ -263,6 +263,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
static_assert(AToken_id::size() == Repeat_M);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
......@@ -371,6 +372,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
register int v_z62 asm("v190") = 0;
register int v_z63 asm("v191") = 0;
index_t temp0 = static_cast<index_t>(row_ids_a_[number<0>{}]);
index_t temp1 = static_cast<index_t>(row_ids_a_[number<1>{}]);
// B nr->kr
#pragma clang diagnostic push
......@@ -397,7 +400,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// [v_acc_13]"+v"(v_acc[13]),
// [v_acc_14]"+v"(v_acc[14]),
// [v_acc_15]"+v"(v_acc[15]),
[v_token_id]"+v"(row_ids_a_),
[v_token_id0]"+v"(temp0),
[v_token_id1]"+v"(temp1),
[s_mem_]"+r"(smem)
: [s_res_aq0]"s"(res_aq[0]),
[s_res_aq1]"s"(res_aq[1]),
......
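The operand rewrite above follows from the earlier signature change: row_ids_a_ is now a per-M-repeat container, and a container cannot be tied to a single "+v" constraint, so each element is first copied into a plain index_t (temp0/temp1) and bound to its own named operand. A reduced sketch of that binding pattern; the instruction and function name are illustrative, not the kernel's actual asm:

    #include <hip/hip_runtime.h>

    __device__ void bind_per_element(const int (&row_ids)[2], int* out)
    {
        int temp0 = row_ids[0]; // scalarized copies, one per "+v" operand
        int temp1 = row_ids[1];
        asm volatile("v_and_b32 %[v_token_id0], 0x00ffffff, %[v_token_id0] \n"
                     "v_and_b32 %[v_token_id1], 0x00ffffff, %[v_token_id1] \n"
                     : [v_token_id0] "+v"(temp0), [v_token_id1] "+v"(temp1));
        out[0] = temp0; // masked token ids
        out[1] = temp1;
    }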
......@@ -98,7 +98,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
index_t tile_offset_half_b, // split the load along K into 2 parts
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(BCoords::size() == 4); // 4
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
......@@ -238,11 +238,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
......@@ -393,11 +388,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
......
......@@ -78,18 +78,17 @@
" v_mov_b32 v53, 0x00007fff \n"
" s_waitcnt 0x0000 \n"
";---------------------------------------------- \n"
" v_mov_b32 %[v_token_id], %[v_token_id] \n"
" v_lshrrev_b32 v54, 24, %[v_token_id] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, %[v_token_id] \n"
" v_add_u32 %[v_token_id], v54, v55 \n"
" v_lshrrev_b32 v54, 24, v7 \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, v7 \n"
" v_add_u32 v7, v54, v55 \n"
" v_lshlrev_b32 %[v_token_id], 2, %[v_token_id] \n"
" v_lshlrev_b32 v7, 2, v7 \n"
" buffer_load_dword v14, %[v_token_id], s[28:31], 0 offen \n"
" v_lshrrev_b32 v54, 24, %[v_token_id0] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, %[v_token_id0] \n"
" v_add_u32 v6, v54, v55 \n"
" v_lshrrev_b32 v54, 24, %[v_token_id1] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, %[v_token_id1] \n"
" v_add_u32 v7, v54, v55 \n"
" v_lshlrev_b32 v6, 2, v6 \n"
" v_lshlrev_b32 v7, 2, v7 \n"
" buffer_load_dword v14, v6, s[28:31], 0 offen \n"
" buffer_load_dword v15, v7, s[28:31], 0 offen \n"
" buffer_load_dword v16, v10, s[32:35], 0 offen \n"
" buffer_load_dword v17, v11, s[32:35], 0 offen \n"
......
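In scalar terms, the rewritten instructions above compute, per lane, the byte offset into the A-scale buffer: split the packed entry into weight index and token index, combine them using the row stride held in s66, and shift left by 2 because each scale is one dword. A hedged scalar equivalent; the exact meaning of s66 is configured elsewhere in the kernel and is treated as an assumption here:

    #include <cstdint>

    // Scalar rehearsal of the per-lane offset fed to buffer_load_dword above.
    inline std::uint32_t a_scale_byte_offset(std::uint32_t packed_token_id,
                                             std::uint32_t row_stride_s66 /* value in SGPR s66 */)
    {
        std::uint32_t weight_idx = packed_token_id >> 24;        // v_lshrrev_b32 v54, 24, ...
        std::uint32_t token_idx  = packed_token_id & 0x00ffffff; // v_and_b32 v55, 0x00ffffff, ...
        std::uint32_t element    = row_stride_s66 * weight_idx + token_idx; // v_mul_i32_i24 + v_add_u32
        return element << 2;                                     // v_lshlrev_b32 ..., 2, ... (dword -> bytes)
    }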
......@@ -804,6 +804,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
{
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 256 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64)
{
return Flatmm_32x512x256_1x4x1_16x16x64_int8{};
}
}
template <typename Problem>
......@@ -851,6 +858,20 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Block_M1 == 32 && S_::Block_N1 == 256 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64 &&
T_::PipeInterleave == false)
{
return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else
{
return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
}
};
} // namespace ck_tile
......@@ -116,7 +116,19 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return coords;
}
CK_TILE_DEVICE auto GetRowCoords_A_mma(index_t base_offset)
{
// constexpr index_t KLans = 2;
constexpr index_t MLans = 16;
constexpr index_t MRepeat = BlockShape::Repeat_M1;
auto base_coord = threadIdx.x % MLans + base_offset;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
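For reference, GetRowCoords_A_mma above gives each lane its position inside a 16-lane M group and then strides by 16 rows per M repeat. A host-side rehearsal of that indexing with an illustrative MRepeat of 2 (the real value is BlockShape::Repeat_M1):

    #include <array>

    std::array<int, 2> row_coords_a_mma_sketch(int thread_idx, int base_offset)
    {
        constexpr int MLans   = 16;
        constexpr int MRepeat = 2; // illustrative; the kernel uses BlockShape::Repeat_M1
        std::array<int, 2> coords{};
        int base_coord = thread_idx % MLans + base_offset;
        for(int i = 0; i < MRepeat; ++i)
            coords[i] = base_coord + i * MLans; // e.g. lane 5, base 0 -> {5, 21}
        return coords;
    }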
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
{
......@@ -178,7 +190,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
/////////////
index_t a_scale_expert_stride_0 = kargs.hidden_size;
index_t g_scale_expert_stride_0 = shared_intermediate_size_0;
index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
index_t d_scale_expert_stride_1 = kargs.hidden_size;
......@@ -192,13 +203,25 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_coords_a_mma = GetRowCoords_A_mma(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto token_id = row_ids_a & 0xffffff;
auto row_ids_a_mma = GetRowID(
row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto token_id = generate_tuple(
[&](auto i) {
return (row_ids_a[i]) & 0xffffff;
},
number<row_ids_a.size()>{});
// auto token_id_mma = generate_tuple(
// [&](auto i) {
// return (row_ids_a_mma[i]) &0xffffff;
// },
// number<row_ids_a_mma.size()>{});
// these are actually addresses
auto a_coords = generate_tuple(
[&](auto i) {
return (token_id[i]) * kargs.stride_token +
return (row_ids_a[i]) * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
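The a_coords tuple above turns each per-repeat row id into an element offset into the A matrix: row times stride_token, plus this lane's slice of the K tile, where each lane owns kAlignmentA contiguous elements and Block_K0 / kAlignmentA lanes cover one row's K tile. A scalar sketch of that formula; the block_k0 and alignment_a defaults below are illustrative, not the configured shape:

    #include <cstdint>

    inline std::int64_t a_element_offset(std::int64_t row_id,
                                         std::int64_t stride_token,
                                         int          thread_idx,
                                         int          block_k0    = 256,
                                         int          alignment_a = 16)
    {
        int lanes_per_k_tile = block_k0 / alignment_a;
        return row_id * stride_token
             + static_cast<std::int64_t>(thread_idx % lanes_per_k_tile) * alignment_a;
    }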
......@@ -208,7 +231,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//////aq
auto aq_win = [&]() {
const AScaleDataType* aq_ptr = reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr);
auto aq_view_ = make_naive_tensor_view<address_space_enum::global>(
auto aq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.num_tokens * kargs.topk),
number<1>{});
......@@ -249,7 +272,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
auto gq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
gq_ptr,
make_tuple(shared_intermediate_size_1),
number<1>{});
......@@ -264,7 +287,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
auto smq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
smq_ptr,
make_tuple(shared_intermediate_size_1),
number<1>{});
......@@ -303,7 +326,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
auto g_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
g_ptr,
make_tuple(kargs.hidden_size),
number<1>{});
......@@ -323,12 +346,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto d_coords = [&]() {
constexpr index_t Nr_ = 4;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 2;//no more need in int8, method changed, this will be handed in res_s
constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4;
// constexpr index_t Kr0_ = 2; // no longer needed for int8; the method changed, this is now handled via res_s
// constexpr index_t Kr1_ = 4;
// constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 16;
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
// constexpr index_t W_ = Kl_ * Nl_ * Kv_;
//constexpr index_t num_offsets_ = Nr_ * Kr0_;
constexpr index_t num_offsets_ = Nr_ ;
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
......@@ -351,18 +374,18 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
number<row_ids_a.size()>{});
auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id, kargs.num_tokens); },
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
number<row_ids_a.size()>{});
auto bridge_sst_win = [&]() {
constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem), desc_),
desc_.get_lengths(),
{0, 0},
dist_);
}();
// auto bridge_sst_win = [&]() {
// constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
// constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
// return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
// reinterpret_cast<YDataType*>(smem), desc_),
// desc_.get_lengths(),
// {0, 0},
// dist_);
// }();
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
......@@ -372,16 +395,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0= uk_0(
row_ids_a,//fake token id, 2D index for X scale
// auto acc_0= uk_0(
uk_0(
row_ids_a_mma, // fake token id, 2D index for X scale
aq_res,
dq_res,
gq_res,
gq_res,
dq_res,
smq_res,
a_res,
a_coords,
g_res,
g_coords,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
......@@ -415,7 +439,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
kargs.hidden_size, // total n number
w_scale,
BlockShape::Block_N1,
shared_intermediate_size_1 * Block_N1 - kr_1 * BlockShape::Block_W1, // along N
shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N
kr_1 * BlockShape::Block_W1,
BlockShape::Block_N1); // along N
}
......