"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "3986e2d2cfadd43d9bb5fbac5ef711f902c06831"
Commit 84755f74 authored by letaoqin

format

parent eab497e8
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
     if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
        t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
     {
-        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
         r = fused_moegemm_<t_>(s, a);
     }
     // clang-format on
...
@@ -40,7 +40,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
     // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
     using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmGl<f_problem>;
     using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
-    using f_kernel      = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
+    using f_kernel      = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>;

    const dim3 grids      = f_kernel::GridSize(a);
    constexpr dim3 blocks = f_kernel::BlockSize();
...
@@ -48,7 +48,7 @@ struct fmoe_ // traits, ugly name, only used for internal
     using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
     ;
-    using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
+    using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
     using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
...
@@ -8,7 +8,7 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
@@ -432,14 +432,18 @@ struct tile_window_linear
     CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
     {
         constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
-        constexpr auto is_pure_linear_tensor = reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
-        if constexpr (is_pure_linear_tensor) {
+        constexpr auto is_pure_linear_tensor =
+            reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
+        if constexpr(is_pure_linear_tensor)
+        {
             // this case usually is a LDS window, everything is build time know.
             // we directly use BottomTensorView to compute the offset, in case there is any padding
-            auto bottom_tensor_coord = make_tensor_coordinate(
-                BottomTensorView{}.get_tensor_descriptor(), linear_coord);
+            auto bottom_tensor_coord =
+                make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
             return bottom_tensor_coord.get_offset();
-        } else {
+        }
+        else
+        {
             // this case usually is a global window, where last dim can be linear
             // we hack here, that use the original TileDstr to compute the linear offset
             // ... hoping that there is no extra padding between other dims, which make sense
...
@@ -135,7 +135,7 @@ void reference_fused_moe(
             for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
             {
                 Activation{}(y(0, i_n), acc_0(0, i_n));
-                //printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
+                // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
             }
         }
         else
...
@@ -620,8 +620,8 @@ struct FastGeluAsm
     CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
     {
         // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
-        const float c1 = -2.0 * 0.035677f;
-        const float c2 = -2.0 * 0.797885f;
+        const float c1   = -2.0 * 0.035677f;
+        const float c2   = -2.0 * 0.797885f;
         const float u0   = x.x * (c1 * x.x * x.x + c2);
         const float emu0 = exp(u0);
         y.x              = x.x / (1.f + emu0);
@@ -641,25 +641,27 @@ struct FastGeluAsm
         float tmp0, tmp1;
         float y0, y1;
-        asm volatile("v_mul_f32 %[v_tmp0], %[v_x0], %[v_x0] ; x*x\n"
-                     "v_mul_f32 %[v_tmp1], %[v_x1], %[v_x1] ; x*x\n"
-                     "v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
-                     "v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
-                     "v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_x0] ; x*(c1*x*x+c2)\n"
-                     "v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_x1] ; x*(c1*x*x+c2)\n"
-                     "v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
-                     "v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
-                     "v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
-                     "v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
-                     "v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
-                     "v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
-                     "v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
-                     "v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
-                     "v_mul_f32 %[v_y0], %[v_tmp0], %[v_x0] ; x * 1/(emu+1f)\n"
-                     "v_mul_f32 %[v_y1], %[v_tmp1], %[v_x1] ; x * 1/(emu+1f)\n"
-                     : [v_y0] "=v"(y0), [v_y1] "=v"(y1), [v_tmp0] "+v"(tmp0), [v_tmp1] "+v"(tmp1)
-                     : [v_x0] "v"(x.x), [v_x1] "v"(x.y), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
-                     :);
+        asm volatile(
+            "v_mul_f32 %[v_tmp0], %[v_x0], %[v_x0] ; x*x\n"
+            "v_mul_f32 %[v_tmp1], %[v_x1], %[v_x1] ; x*x\n"
+            "v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
+            "v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
+            "v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_x0] ; x*(c1*x*x+c2)\n"
+            "v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_x1] ; x*(c1*x*x+c2)\n"
+            "v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
+            "v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
+            "v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
+            "v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
+            "v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
+            "v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
+            "v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
+            "v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
+            "v_mul_f32 %[v_y0], %[v_tmp0], %[v_x0] ; x * 1/(emu+1f)\n"
+            "v_mul_f32 %[v_y1], %[v_tmp1], %[v_x1] ; x * 1/(emu+1f)\n"
+            : [v_y0] "=v"(y0), [v_y1] "=v"(y1), [v_tmp0] "+v"(tmp0), [v_tmp1] "+v"(tmp1)
+            :
+            [v_x0] "v"(x.x), [v_x1] "v"(x.y), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
+            :);
         y.x = y0;
         y.y = y1;
     }
...
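Both the host fallback and the asm path above evaluate the same FastGelu approximation, y = x / (1 + e^u) with u = -2x(0.035677x^2 + 0.797885); the asm folds e^u into v_exp_f32 (which computes 2^x) by pre-scaling u with log2(e). A minimal scalar sketch of the same math (illustrative only, not part of this commit; fast_gelu_ref is a hypothetical helper name):

    #include <cmath>

    // scalar equivalent of one lane of the two-lane asm body in FastGeluAsm
    float fast_gelu_ref(float x)
    {
        const float c1 = -2.f * 0.035677f;
        const float c2 = -2.f * 0.797885f;
        const float u  = x * (c1 * x * x + c2); // u = -2*x*(0.035677*x*x + 0.797885)
        return x / (1.f + std::exp(u));         // equals x * sigmoid(-u)
    }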
@@ -72,7 +72,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
                        sequence<2, 1>, // !! note here is different
                        sequence<0, 0>>{};
         using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
-        constexpr auto c_block_dstr_encode  = detail::make_embed_tile_distribution_encoding(
+        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
             c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
@@ -82,7 +82,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
     static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
     {
         using CDataType = float;
-        constexpr auto c_block_dstr  = MakeCBlockDist();
+        constexpr auto c_block_dstr = MakeCBlockDist();
         auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
         return c_block_tensor;
@@ -180,8 +180,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
     CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
     {
         // load from LDS to register, every wave has same layout
-        constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
-        constexpr index_t KPad = KPack_; // pad between warps
+        constexpr index_t KPack_ = 8;      // GetSmemKPack_A<Problem>(); // LDS
+        constexpr index_t KPad   = KPack_; // pad between warps
         constexpr index_t kAMLane = 16;
         constexpr index_t kABKLane = 4;
@@ -189,26 +189,25 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
         constexpr index_t kKIter = 2;
         static_assert(KPack_ == (kABKPerLane * kKIter));
-        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<Repeat_M>{}, // m0 y
-                       number<kAMLane>{},  // m1 p
-                       number<Repeat_K>{}, // k0 y
-                       number<kABKLane>{}, // k1 p
-                       number<KPack_>{}),  // k2 y-vector
-            make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
-                       number<Block_K + KPad>{},           // m1
-                       number<kABKLane * KPack_>{},        // k0
-                       number<KPack_>{},                   // k1
-                       number<1>{}),                       // k2
-            number<KPack_>{}, // lds load vector
-            number<1>{});
+        constexpr auto lds_block_desc_0 =
+            make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
+                                                    number<kAMLane>{},  // m1 p
+                                                    number<Repeat_K>{}, // k0 y
+                                                    number<kABKLane>{}, // k1 p
+                                                    number<KPack_>{}),  // k2 y-vector
+                                         make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
+                                                    number<Block_K + KPad>{},           // m1
+                                                    number<kABKLane * KPack_>{},        // k0
+                                                    number<KPack_>{},                   // k1
+                                                    number<1>{}),                       // k2
+                                         number<KPack_>{}, // lds load vector
+                                         number<1>{});
         constexpr auto lds_desc_m_k = transform_tensor_descriptor(
             lds_block_desc_0,
-            make_tuple(make_merge_transform(
-                           make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
-                       make_merge_transform(make_tuple(
-                           number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
+            make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
+                       make_merge_transform(
+                           make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
             make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
@@ -291,12 +290,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             },
             number<a_sld.get_num_of_access()>{});
         // printf("----- tid:%d, a_sld:%d\n", static_cast<index_t>(threadIdx.x),
         // static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset()));

         index_t loop_cnt = k / Block_K;

         // this is the acc thread buffer
...
@@ -4,6 +4,7 @@
 #pragma once

 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
+#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp"
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
 #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2) need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// topk_row_id(token_id) = x % num_tokens(5)
// topk_col_id(expert_id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
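// e.g. x = D = 13 decodes to topk_row_id = 13 % 5 = 3 and topk_col_id = 13 / 5 = 2,
// i.e. topk_ids[3][2] = 3 (expert 3) with weight topk_weight[3][2] = l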
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
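// a minimal host-side sketch of implementation-1 above (illustrative only; the
// function and variable names are hypothetical, and the real sorting runs on the
// GPU in moe_sorting_kernel.hpp):
//
//   void reference_moe_sort(const std::vector<std::vector<int>>&   topk_ids,
//                           const std::vector<std::vector<float>>& topk_weight,
//                           int num_experts, int block_m /*M_a*/,
//                           std::vector<int>&   sorted_token_ids,
//                           std::vector<float>& sorted_weight,
//                           std::vector<int>&   sorted_expert_ids)
//   {
//       const int num_tokens = static_cast<int>(topk_ids.size());
//       const int pad_id     = num_tokens + 1; // pad marker, the "6" in the example above
//       for(int e = 0; e < num_experts; e++)
//       {
//           std::vector<int>   tok;
//           std::vector<float> wgt;
//           for(int t = 0; t < num_tokens; t++)
//               for(std::size_t k = 0; k < topk_ids[t].size(); k++)
//                   if(topk_ids[t][k] == e)
//                   {
//                       tok.push_back(t);
//                       wgt.push_back(topk_weight[t][k]);
//                   }
//           // every expert occupies a whole number of M_a-sized tiles (at least one)
//           std::size_t padded = ((tok.size() + block_m - 1) / block_m) * block_m;
//           if(padded == 0)
//               padded = block_m;
//           while(tok.size() < padded)
//           {
//               tok.push_back(pad_id);
//               wgt.push_back(0.f);
//           }
//           sorted_token_ids.insert(sorted_token_ids.end(), tok.begin(), tok.end());
//           sorted_weight.insert(sorted_weight.end(), wgt.begin(), wgt.end());
//           for(std::size_t i = 0; i < padded / block_m; i++)
//               sorted_expert_ids.push_back(e);
//       }
//   }
//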
// clang-format on
//
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: hidden_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : flattened 1d wave buffer
// struct FusedMoeGemmHostArgs
// {
// const void* a_ptr; // [m, k], input token
// const void* a_scale_ptr; // [m, 1], token scale
// const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
// const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
// const void* g_scale_ptr; // [e, 1, n], gate(up) scale
// const void* d_scale_ptr; // [e, 1, k], down scale
// const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
// void* o_ptr; // [m, k], output token
// const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
// const void* sorted_weight_ptr; // [max_num_tokens_padded]
// const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
// const void* num_sorted_tiles_ptr; // [1]
// index_t hidden_size; // k
// index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
// index_t num_tokens; // input number of tokens for current iteration
// index_t num_experts; // number of groups
// index_t topk; // need this?
// index_t stride_token; // for input/output, stride for each row, should >= hidden_size
// };
// This is a scatter/gather, back-to-back (b2b) group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmGlKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
using DDataType = typename Pipeline::Problem::DDataType;
using AccDataType = typename Pipeline::Problem::AccDataType;
using ODataType = typename Pipeline::Problem::ODataType;
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = BlockShape;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, GDataType>) {
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
}
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
// clang-format on
}
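    // e.g. for the bf16 instance registered above this yields a name along the lines of
    // "fused_moe_bf16_32x128x128x128_1x4x1_32x32x8_<Pipeline::name>" (illustrative; the
    // exact fields depend on the FusedMoeGemmShape instantiation)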
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using Kargs = FusedMoeGemmKargs;
using Hargs = FusedMoeGemmHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
// TODO: hargs/kargs not guaranteed to be the same
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
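// gate-only weights are laid out [e, n, k]; fused gate+up is [e, 2*n, k], hence the ratio of 2 below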
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note these are in units of tiles; multiply by the tile size to get the element index (i_m and i_n)
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in the previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// the gather happens here, via an indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
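// the indexing transform redirects this thread's window row to row `token_id`
// of A, so the plain 2d tile load below becomes a row gather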
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// the scatter happens here (output rows written back via the same indexing transform)
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size,
kargs.stride_token);
}
};
} // namespace ck_tile
@@ -70,12 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmGl
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
+        // matrix a or tokens smem
+        constexpr index_t smem_mat_a =
+            BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
+        // shuffle C matrix
         constexpr index_t smem_bridge =
             BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
-        return smem_bridge;
+        return max(smem_mat_a, smem_bridge);
     }

     template <typename Karg>
     CK_TILE_DEVICE auto operator()(const Karg& kargs,
                                    CK_TILE_LDS_ADDR void* smem,
@@ -86,7 +90,6 @@ struct FusedMoeGemmPipeline_FlatmmGl
         ignore = smem;
         ignore = sorted_tile_id;
         ignore = intermediate_tile_id;
     }
 };
...
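Sizing check for the new GetSmemSize: with the bf16 instance above (Block_M0 = 32, Block_K0 = Block_N0 = 128), smem_mat_a = 32 * 128 * sizeof(bf16) = 8 KiB, while smem_bridge = 32 * 128 * sizeof(YDataType); if YDataType is also a 2-byte type the two are equal, so taking the max covers the new A-tile staging without growing the block's LDS footprint.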
@@ -590,39 +590,40 @@ struct FusedMoeGemmPipelineFlatmmPolicy
     CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
     {
         constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
         constexpr index_t Repeat_N       = Problem::BlockShape::Repeat_N0;
         constexpr index_t Repeat_M       = Problem::BlockShape::Repeat_M0;
         constexpr index_t kAMLane     = 16;
         constexpr index_t kABKLane    = 4;
         constexpr index_t kABKPerLane = 4;
         constexpr index_t KPack       = kABKPerLane;

         constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<Repeat_M>{},       // m
                        number<Repeat_N>{},       // n
                        number<WarpPerBlock_N>{}, // n
                        number<kABKLane>{},       // n
                        number<kAMLane>{},        // m
                        number<KPack>{}),         // n
             make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
                        number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{},            // n
                        number<kABKLane * kAMLane * KPack>{},                             // n
                        number<kAMLane * KPack>{},                                        // n
                        number<KPack>{},                                                  // m
                        number<1>{}),                                                     // n
             number<KPack>{}, // lds store vector(actually no explicit store)
             number<1>{});

         constexpr auto desc = transform_tensor_descriptor(
             lds_block_desc_0,
-            make_tuple(
-                make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
-                make_merge_transform(make_tuple(number<Repeat_N>{}, number<WarpPerBlock_N>{}, number<kABKLane>{}, number<KPack>{}))
-            ),
+            make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
+                       make_merge_transform(make_tuple(number<Repeat_N>{},
+                                                       number<WarpPerBlock_N>{},
+                                                       number<kABKLane>{},
+                                                       number<KPack>{}))),
             make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));

         return desc;
     }
...
@@ -342,13 +342,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
         auto bridge_sst_win = [&]() {
             constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
             constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
-            return make_tile_window_linear(
-                make_tensor_view<address_space_enum::lds>(
-                    reinterpret_cast<YDataType*>(smem),
-                    desc_),
-                desc_.get_lengths(),
-                {0, 0},
-                dist_);
+            return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
+                                               reinterpret_cast<YDataType*>(smem), desc_),
+                                           desc_.get_lengths(),
+                                           {0, 0},
+                                           dist_);
         }();

         auto o_res =
             make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
@@ -442,16 +440,17 @@ struct FusedMoeGemmPipeline_FlatmmUk
                            BlockShape::Block_W0); // tile offset for B matrix each unroll
         // return ;
-        //sweep_tile(acc_0,
-        //           [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
-        sweep_tile(acc_0,
-                   [&](auto idx0, auto idx1) {
-                       fp32x2_t v_ {acc_0(idx0), acc_0(idx1)};
-                       typename Problem::GateActivation{}(v_, v_);
-                       acc_0(idx0) = v_.x;
-                       acc_0(idx1) = v_.y;
-                   },
-                   sequence<1, 2>{});
+        // sweep_tile(acc_0,
+        //            [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
+        sweep_tile(
+            acc_0,
+            [&](auto idx0, auto idx1) {
+                fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
+                typename Problem::GateActivation{}(v_, v_);
+                acc_0(idx0) = v_.x;
+                acc_0(idx1) = v_.y;
+            },
+            sequence<1, 2>{});
 #if 0
         printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
...