Commit d0c80b12 authored by shengnxu

fix more issues; current status: inline asm uses more registers than are available

parent a759277d
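Note on the status line above: the int8 micro-kernels below pin specific VGPRs by name (register ... asm("v190")) and hand a long operand list to a single asm block, so the named registers plus whatever the compiler still needs can exceed the per-wave VGPR budget; that is what the "more registers than available" diagnostic refers to. A minimal, hypothetical sketch of the pinning pattern under a HIP/Clang AMDGPU target (the kernel name and register choice are illustrative only, not part of this commit):

#include <hip/hip_runtime.h>

// Each named binding reserves one concrete VGPR for the whole block; with many pinned
// accumulators (the uk below pins registers up to v191) plus the asm operands, the
// budget can be exhausted and the compiler reports that the inline asm needs more
// registers than are available.
__global__ void pinned_vgpr_sketch(int* out)
{
    register int acc0 asm("v190") = 0; // pinned accumulator
    register int acc1 asm("v191") = 0; // pinned accumulator
    asm volatile("v_add_u32 %0, %0, %1" : "+v"(acc0) : "v"(acc1));
    out[threadIdx.x] = acc0;
}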
@@ -37,3 +37,4 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
     // clang-format on
     return r;
 }
@@ -5,11 +5,26 @@
 #include "fused_moegemm_api_traits.hpp"
 #include "ck_tile/ops/fused_moe.hpp"
+#include "fused_moegemm_api.cpp"
 #include <iostream>
 
 template <ck_tile::index_t... Is>
 using S = ck_tile::sequence<Is...>;
 
+template <typename dtype, typename problem>
+struct PipelineDispatch;
+
+template <typename problem> struct PipelineDispatch<ck_tile::int8_t, problem> {
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<problem>;
+};
+template <typename problem> struct PipelineDispatch<ck_tile::bf16_t, problem> {
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
+};
+template <typename problem> struct PipelineDispatch<ck_tile::fp16_t, problem> {
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
+};
+
 // do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j
 template <typename Ts_>
 float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
@@ -38,8 +53,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
         f_traits>;
 
     // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
-    using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<
-        >;
+    using f_pipeline = typename PipelineDispatch<typename Ts_::ADataType, f_problem>::type;
     using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
     using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
......
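The PipelineDispatch trait added above picks the pipeline from the activation data type alone, so fused_moegemm_ no longer hard-codes the int8 pipeline. A self-contained sketch of the same selection pattern, with placeholder structs standing in for the real ck_tile pipeline types:

#include <cstdint>
#include <type_traits>

struct PipelineInt8 {};    // stands in for FusedMoeGemmPipeline_FlatmmUk_int8<problem>
struct PipelineDefault {}; // stands in for FusedMoeGemmPipeline_FlatmmUk<problem>

// Primary template: everything that is not specialized uses the default pipeline.
template <typename DataType>
struct PipelineDispatchDemo { using type = PipelineDefault; };

// int8 activations get the int8 pipeline.
template <>
struct PipelineDispatchDemo<std::int8_t> { using type = PipelineInt8; };

template <typename ADataType>
void instantiate_kernel()
{
    using pipeline = typename PipelineDispatchDemo<ADataType>::type;
    static_assert(std::is_same_v<pipeline, PipelineInt8> ||
                  std::is_same_v<pipeline, PipelineDefault>);
}

int main()
{
    instantiate_kernel<std::int8_t>();
    instantiate_kernel<float>();
    return 0;
}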
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 256, 256>, S<1, 4, 1>, S<16, 16, 64>, 1, 1>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
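This instance file relies on explicit template instantiation: the definition of fused_moegemm_ lives in the internal header (see the comment above about not defining it in _api.cpp), and each instance .cpp instantiates exactly one specialization, so make -j can compile the expensive kernels in parallel translation units. A single-file sketch of the mechanism, with an illustrative heavy_gemm standing in for fused_moegemm_:

// heavy_gemm stands in for fused_moegemm_<Ts_>; in the real repo the definition sits in
// the shared internal header and the explicit instantiation lives in one instance .cpp.
template <typename T>
float heavy_gemm(int n)
{
    float acc = 0.f;
    for(int i = 0; i < n; ++i)
        acc += static_cast<float>(static_cast<T>(i));
    return acc;
}

// Explicit instantiation: the template body is compiled exactly once, here; other
// translation units only need the declaration and link against this object file.
template float heavy_gemm<signed char>(int);

int main() { return heavy_gemm<signed char>(4) == 6.f ? 0 : 1; }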
@@ -204,7 +204,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ADataType = typename TypeConfig::ADataType;
     using GDataType = typename TypeConfig::GDataType;
     using DDataType = typename TypeConfig::DDataType;
-    using AccDataType = typename TypeConfig::AccDataType;
+    // using AccDataType = typename TypeConfig::AccDataType;
     using ODataType = typename TypeConfig::ODataType;
     using AScaleDataType = typename TypeConfig::AScaleDataType;
     using GScaleDataType = typename TypeConfig::GScaleDataType;
@@ -218,12 +218,12 @@
     ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
     ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
-    if (fused_quant == 1)
-    {
+    // if (fused_quant == 1)
+    // {
     ck_tile::HostTensor<AScaleDataType> sa_host({tokens, topk});
-    } else{
-        ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
-    }
+    // } else{
+    //     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
+    // }
     ck_tile::HostTensor<GScaleDataType> sg_host({experts, shared_intermediate_size_0});
     ck_tile::HostTensor<DScaleDataType> sd_host({experts, shared_intermediate_size_1});
     ck_tile::HostTensor<YSmoothScaleDataType> sy_host({experts, shared_intermediate_size_1}); // smooth-quant
@@ -425,7 +425,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         experts,
         block_m);
-    ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
+    ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
         a_host,
         g_host,
         d_host,
@@ -535,7 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
+        ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
             a_host,
             g_host,
             d_host,
@@ -555,7 +555,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
             hidden_size,
            shared_intermediate_size_0,
            topk,
-           gate_only);
+           gate_only,
+           1);
     auto o_dev = o_buf.ToHost<ODataType>();
     // o_dev.savetxt("gpu-out.txt", "float");
@@ -604,6 +605,13 @@ int main(int argc, char* argv[])
                    ? 0
                    : -2;
     }
+    else if(prec_i == "int8" && prec_w == "int8" && prec_o == "bf16" && prec_kw == "fp32")
+    {
+        return run<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, float, float, float, float>(
+                   arg_parser)
+                   ? 0
+                   : -2;
+    }
     return -3;
 }
@@ -107,10 +107,9 @@ void reference_fused_moe(
             return;
         ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
         ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
-        ck_tile::index_t i_weight_idx;
+        ck_tile::index_t i_weight_idx = i_token >> 24;
         if(fquant == 1)
         {
-            i_weight_idx = i_token >> 24;
             i_token = i_token & 0xffffff;
         }
         if (i_token >= tokens)
......
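The reference change above assumes the sorted token ids pack two fields into one 32-bit entry: the topk-weight slot in the top 8 bits and the real token id in the low 24 bits, which is why i_weight_idx can now be computed unconditionally. A minimal sketch of the encode/decode (the pack helper is illustrative, not part of the repo):

#include <cassert>
#include <cstdint>

// Pack a topk-weight index (< 256) and a token id (< 2^24) into one int32,
// matching the `i_token >> 24` / `i_token & 0xffffff` decode in reference_fused_moe.
inline std::int32_t pack_sorted_token_id(std::int32_t token, std::int32_t weight_idx)
{
    return (weight_idx << 24) | (token & 0xffffff);
}

int main()
{
    std::int32_t packed = pack_sorted_token_id(/*token=*/1234, /*weight_idx=*/3);
    std::int32_t i_weight_idx = packed >> 24;       // as in the reference
    std::int32_t i_token      = packed & 0xffffff;
    assert(i_weight_idx == 3 && i_token == 1234);
    return 0;
}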
@@ -245,9 +245,9 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
     // TODO: needs to be paired with tile_window_linear!
     // TODO: init_raw() must be called before calling this function!
-    template <typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
+    template <typename AToken_id, typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
     CK_TILE_DEVICE auto
-    operator()( index_t row_ids_a_,
+    operator()( const AToken_id& row_ids_a_,
                 const AQRes& res_aq,
                 const DQRes& res_dq,
                 const GQRes& res_gq,
@@ -263,6 +263,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
     {
         static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
         static_assert(BCoords::size() == Repeat_N);
+        static_assert(AToken_id::size() == Repeat_M);
         auto a_sst = make_tile_window(
             make_tensor_view<address_space_enum::lds>(
@@ -371,6 +372,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
         register int v_z62 asm("v190") = 0;
         register int v_z63 asm("v191") = 0;
+        index_t temp0 = static_cast<index_t>(row_ids_a_[number<0>{}]);
+        index_t temp1 = static_cast<index_t>(row_ids_a_[number<1>{}]);
         // B nr->kr
 #pragma clang diagnostic push
@@ -397,7 +400,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
             // [v_acc_13]"+v"(v_acc[13]),
             // [v_acc_14]"+v"(v_acc[14]),
             // [v_acc_15]"+v"(v_acc[15]),
-            [v_token_id]"+v"(row_ids_a_),
+            [v_token_id0]"+v"(temp0),
+            [v_token_id1]"+v"(temp1),
             [s_mem_]"+r"(smem)
             : [s_res_aq0]"s"(res_aq[0]),
               [s_res_aq1]"s"(res_aq[1]),
......
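The operand change above replaces the single bound row-id operand (plus the hard-coded v7) with two plain index_t temporaries, so each token id sits in its own VGPR and is passed to the asm as its own named read-write operand. A reduced, compile-only sketch of that "+v" binding pattern (HIP/Clang device code assumed; the opcode and masking are lifted from the uk, but the function itself is illustrative):

#include <hip/hip_runtime.h>

// Two separate read-write VGPR operands, mirroring [v_token_id0]/[v_token_id1] in the uk.
__device__ void bind_two_token_ids(int& token0, int& token1)
{
    int temp0 = token0; // previously a single packed operand carried both rows
    int temp1 = token1;
    asm volatile(" v_and_b32 %[v_token_id0], 0x00ffffff, %[v_token_id0] \n"
                 " v_and_b32 %[v_token_id1], 0x00ffffff, %[v_token_id1] \n"
                 : [v_token_id0] "+v"(temp0), [v_token_id1] "+v"(temp1));
    token0 = temp0;
    token1 = temp1;
}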
@@ -98,7 +98,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
                index_t tile_offset_half_b, // load along K is split into 2 parts
                index_t tile_offset_o)
     {
-        static_assert(BCoords::size() == 8); // 8
+        static_assert(BCoords::size() == 4); // 8
         static_assert(OCoords::size() == 8);
         const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
@@ -238,11 +238,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
              [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
-             [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
-             [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
-             [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
-             [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
              [s_tile_os_o]"s"(tile_stride_o_bytes),
              [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
              [s_tile_os_b]"s"(tile_stride_b_bytes),
@@ -393,11 +388,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
              [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
-             [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
-             [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
-             [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
-             [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
              [s_tile_os_o]"s"(tile_stride_o_bytes),
              [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
              [s_tile_os_b]"s"(tile_stride_b_bytes),
......
@@ -78,18 +78,17 @@
         " v_mov_b32 v53, 0x00007fff                        \n"
         " s_waitcnt 0x0000                                 \n"
         ";----------------------------------------------   \n"
-        " v_mov_b32 %[v_token_id], %[v_token_id]           \n"
-        " v_lshrrev_b32 v54, 24, %[v_token_id]             \n"
-        " v_mul_i32_i24 v54, s66, v54                      \n"
-        " v_and_b32 v55, 0x00ffffff, %[v_token_id]         \n"
-        " v_add_u32 %[v_token_id], v54, v55                \n"
-        " v_lshrrev_b32 v54, 24, v7                        \n"
-        " v_mul_i32_i24 v54, s66, v54                      \n"
-        " v_and_b32 v55, 0x00ffffff, v7                    \n"
-        " v_add_u32 v7, v54, v55                           \n"
-        " v_lshlrev_b32 %[v_token_id], 2, %[v_token_id]    \n"
-        " v_lshlrev_b32 v7, 2, v7                          \n"
-        " buffer_load_dword v14, %[v_token_id], s[28:31], 0 offen \n"
+        " v_lshrrev_b32 v54, 24, %[v_token_id0]            \n"
+        " v_mul_i32_i24 v54, s66, v54                      \n"
+        " v_and_b32 v55, 0x00ffffff, %[v_token_id0]        \n"
+        " v_add_u32 v6, v54, v55                           \n"
+        " v_lshrrev_b32 v54, 24, %[v_token_id1]            \n"
+        " v_mul_i32_i24 v54, s66, v54                      \n"
+        " v_and_b32 v55, 0x00ffffff, %[v_token_id1]        \n"
+        " v_add_u32 v7, v54, v55                           \n"
+        " v_lshlrev_b32 v6, 2, v6                          \n"
+        " v_lshlrev_b32 v7, 2, v7                          \n"
+        " buffer_load_dword v14, v6, s[28:31], 0 offen     \n"
         " buffer_load_dword v15, v7, s[28:31], 0 offen     \n"
         " buffer_load_dword v16, v10, s[32:35], 0 offen    \n"
         " buffer_load_dword v17, v11, s[32:35], 0 offen    \n"
......
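Read as scalar code, the rewritten sequence above computes, for each of the two rows, a dword offset from the packed id: the top byte selects the weight slot, s66 holds a per-slot stride, the low 24 bits are the token id, and the final shift converts the element index into a byte offset for buffer_load_dword. A host-side C++ sketch of that arithmetic (the stride parameter is an assumption standing in for s66, and the semantic reading of the offset is inferred from the surrounding scale loads):

#include <cstdint>

// Scalar equivalent of the v_lshrrev/v_mul_i32_i24/v_and/v_add/v_lshl sequence:
// byte offset = ((packed >> 24) * stride + (packed & 0xffffff)) * 4
inline std::uint32_t token_scale_byte_offset(std::uint32_t packed_token_id,
                                             std::uint32_t stride /* value held in s66 */)
{
    std::uint32_t weight_slot = packed_token_id >> 24;       // v_lshrrev_b32
    std::uint32_t scaled      = stride * weight_slot;        // v_mul_i32_i24
    std::uint32_t token       = packed_token_id & 0xffffff;  // v_and_b32
    std::uint32_t elem        = scaled + token;              // v_add_u32
    return elem << 2;                                        // v_lshlrev_b32 (dword -> bytes)
}

int main()
{
    std::uint32_t packed = (3u << 24) | 1234u;
    return token_scale_byte_offset(packed, 7u) == ((3u * 7u + 1234u) << 2) ? 0 : 1;
}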
@@ -804,6 +804,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         {
             return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
         }
+        else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
+                          std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
+                          S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 256 &&
+                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64)
+        {
+            return Flatmm_32x512x256_1x4x1_16x16x64_int8{};
+        }
     }
     template <typename Problem>
@@ -851,6 +858,20 @@ struct FusedMoeGemmPipelineFlatmmPolicy
             // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
             return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
         }
+        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
+                          std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
+                          S_::Block_M1 == 32 && S_::Block_N1 == 256 && S_::Block_K1 == 512 &&
+                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64 &&
+                          T_::PipeInterleave == false)
+        {
+            return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
+            // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
+        }
+        else
+        {
+            return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
+            // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
+        }
     }
 };
 } // namespace ck_tile
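Both GetUK additions above extend the policy's existing if constexpr dispatch, where the function's return type changes with the problem's data types and tile shape; this only compiles because the non-matching branches are discarded, and the new trailing else in GetUK_1 guarantees that some type is always returned. A stripped-down sketch of the pattern, with placeholder tag structs standing in for the Flatmm uk types:

#include <cstdint>
#include <type_traits>

struct UkFp16 {}; // placeholder for Flatmm_32x512x128_1x4x1_16x16x32_FP16
struct UkInt8 {}; // placeholder for Flatmm_32x512x256_1x4x1_16x16x64_int8

// The selected branch determines the deduced return type; discarded branches
// are never instantiated, so the types do not have to be related.
template <typename AData, int BlockM, int BlockN, int BlockK>
constexpr auto get_uk_0()
{
    if constexpr(std::is_same_v<AData, std::int8_t> && BlockM == 32 && BlockN == 512 &&
                 BlockK == 256)
    {
        return UkInt8{};
    }
    else
    {
        return UkFp16{};
    }
}

static_assert(std::is_same_v<decltype(get_uk_0<std::int8_t, 32, 512, 256>()), UkInt8>);
static_assert(std::is_same_v<decltype(get_uk_0<float, 32, 512, 128>()), UkFp16>);

int main() { return 0; }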
@@ -116,7 +116,19 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         return coords;
     }
 
+    CK_TILE_DEVICE auto GetRowCoords_A_mma(index_t base_offset)
+    {
+        // constexpr index_t KLans = 2;
+        constexpr index_t MLans = 16;
+        constexpr index_t MRepeat = BlockShape::Repeat_M1;
+
+        auto base_coord = threadIdx.x % MLans + base_offset;
+
+        array<index_t, MRepeat> coords;
+        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
+        return coords;
+    }
+
     template <typename ROW_COORDS>
     CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
     {
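GetRowCoords_A_mma above produces one row coordinate per M-repeat of the mma layout: the lane given by threadIdx.x % 16 owns rows base + i * 16. A host-side mirror of the same indexing, with the lane and repeat counts as plain parameters (values taken from the 16-wide M dimension of the warp tile; the function name is illustrative):

#include <array>
#include <cstdio>

// Host-side mirror of GetRowCoords_A_mma: each of the MLanes M-lanes covers
// MRepeat rows of the block, strided by the lane count.
template <int MLanes, int MRepeat>
std::array<int, MRepeat> row_coords_a_mma(int thread_id, int base_offset)
{
    std::array<int, MRepeat> coords{};
    int base_coord = thread_id % MLanes + base_offset;
    for(int i = 0; i < MRepeat; ++i)
        coords[i] = base_coord + i * MLanes;
    return coords;
}

int main()
{
    auto c = row_coords_a_mma<16, 2>(/*thread_id=*/21, /*base_offset=*/64);
    std::printf("%d %d\n", c[0], c[1]); // 69 85: the two sorted-token rows this lane reads
    return 0;
}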
@@ -178,7 +190,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
         index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
         /////////////
-        index_t a_scale_expert_stride_0 = kargs.hidden_size;
         index_t g_scale_expert_stride_0 = shared_intermediate_size_0;
         index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
         index_t d_scale_expert_stride_1 = kargs.hidden_size;
@@ -192,13 +203,25 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                                        BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
 
         auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
+        auto row_coords_a_mma = GetRowCoords_A_mma(sorted_tile_id * BlockShape::Block_M0);
         auto row_ids_a = GetRowID(
             row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
-        auto token_id = row_ids_a & 0xffffff;
+        auto row_ids_a_mma = GetRowID(
+            row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
+        auto token_id = generate_tuple(
+            [&](auto i) {
+                return (row_ids_a[i]) & 0xffffff;
+            },
+            number<row_ids_a.size()>{});
+        // auto token_id_mma = generate_tuple(
+        //     [&](auto i) {
+        //         return (row_ids_a_mma[i]) & 0xffffff;
+        //     },
+        //     number<row_ids_a_mma.size()>{});
 
         //addr in fact
         auto a_coords = generate_tuple(
             [&](auto i) {
-                return (token_id[i]) * kargs.stride_token +
+                return (row_ids_a[i]) * kargs.stride_token +
                        threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
             },
             number<row_ids_a.size()>{});
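In the rewritten block above, the masked token_id tuple is kept for the per-row validity flags further down (cmp_lt_to_exec against num_tokens), while a_coords now indexes the activation rows with the raw row id. A plain C++ sketch of those two per-row computations without the ck_tile tuple machinery (array sizes and numeric values are illustrative):

#include <array>
#include <cstdint>

// Mirror of the two generate_tuple lambdas in the pipeline: the masked id feeds the
// per-row validity flag, the raw row id feeds the address computation.
template <std::size_t N>
void make_row_offsets(const std::array<std::int32_t, N>& row_ids,
                      std::int64_t stride_token,
                      std::int64_t lane_offset,
                      std::int32_t num_tokens,
                      std::array<std::int64_t, N>& offsets,
                      std::array<bool, N>& valid)
{
    for(std::size_t i = 0; i < N; ++i)
    {
        std::int32_t token = row_ids[i] & 0xffffff;                   // token_id[i]
        valid[i]           = token < num_tokens;                      // o_flags analogue
        offsets[i]         = row_ids[i] * stride_token + lane_offset; // a_coords[i]
    }
}

int main()
{
    std::array<std::int32_t, 2> row_ids{42, (1 << 24) | 7};
    std::array<std::int64_t, 2> offsets{};
    std::array<bool, 2> valid{};
    make_row_offsets(row_ids, /*stride_token=*/4096, /*lane_offset=*/16, /*num_tokens=*/64,
                     offsets, valid);
    return (valid[0] && valid[1] && offsets[0] == 42 * 4096 + 16) ? 0 : 1;
}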
@@ -208,7 +231,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         //////aq
         auto aq_win = [&]() {
             const AScaleDataType* aq_ptr = reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr);
-            auto aq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto aq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 aq_ptr,
                 make_tuple(kargs.num_tokens * kargs.topk),
                 number<1>{});
@@ -249,7 +272,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                                          static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
                                          intermediate_tile_id * BlockShape::Block_N0;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
-            auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto gq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 gq_ptr,
                 make_tuple(shared_intermediate_size_1),
                 number<1>{});
@@ -264,7 +287,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                                           static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
                                           intermediate_tile_id * BlockShape::Block_N0;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
-            auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto smq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 smq_ptr,
                 make_tuple(shared_intermediate_size_1),
                 number<1>{});
@@ -303,7 +326,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
                                           static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
-            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto g_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 g_ptr,
                 make_tuple(kargs.hidden_size),
                 number<1>{});
@@ -323,12 +346,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         auto d_coords = [&]() {
             constexpr index_t Nr_ = 4;
             constexpr index_t Nw_ = 4;
-            constexpr index_t Kr0_ = 2; // no longer needed for int8; method changed, this is now handled in res_s
-            constexpr index_t Kr1_ = 4;
-            constexpr index_t Kl_ = 4;
+            // constexpr index_t Kr0_ = 2; // no longer needed for int8; method changed, this is now handled in res_s
+            // constexpr index_t Kr1_ = 4;
+            // constexpr index_t Kl_ = 4;
             constexpr index_t Nl_ = 16;
             constexpr index_t Kv_ = 16;
-            constexpr index_t W_ = Kl_ * Nl_ * Kv_;
+            // constexpr index_t W_ = Kl_ * Nl_ * Kv_;
             //constexpr index_t num_offsets_ = Nr_ * Kr0_;
             constexpr index_t num_offsets_ = Nr_;
             index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
@@ -351,18 +374,18 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             number<row_ids_a.size()>{});
 
         auto o_flags =
-            generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id, kargs.num_tokens); },
+            generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
                            number<row_ids_a.size()>{});
 
-        auto bridge_sst_win = [&]() {
-            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
-            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
-            return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
-                                               reinterpret_cast<YDataType*>(smem), desc_),
-                                           desc_.get_lengths(),
-                                           {0, 0},
-                                           dist_);
-        }();
+        // auto bridge_sst_win = [&]() {
+        //     constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
+        //     constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
+        //     return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
+        //                                        reinterpret_cast<YDataType*>(smem), desc_),
+        //                                    desc_.get_lengths(),
+        //                                    {0, 0},
+        //                                    dist_);
+        // }();
 
         auto o_res =
             make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
                                       kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
@@ -372,16 +395,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
 
         auto uk_0 = Policy::template GetUK_0<Problem>();
-        auto acc_0 = uk_0(
-            row_ids_a, // fake token id, 2D index for X scale
+        // auto acc_0 = uk_0(
+        uk_0(
+            row_ids_a_mma, // fake token id, 2D index for X scale
             aq_res,
+            dq_res,
             gq_res,
-            gq_res,
-            dq_res,
+            smq_res,
             a_res,
             a_coords,
             g_res,
             g_coords,
             smem,
             kargs.hidden_size,
             BlockShape::Block_K0, // tile offset for B matrix each unroll
@@ -415,7 +439,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_N1,
-            shared_intermediate_size_1 * Block_N1 - kr_1 * BlockShape::Block_W1, // along N
+            shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N
             kr_1 * BlockShape::Block_W1,
             BlockShape::Block_N1); // along N
     }
......