Commit 06914eed authored by carlushuang's avatar carlushuang
Browse files

block-asm

parent b0dd570a
......@@ -12,6 +12,7 @@ set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
......
......@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -33,11 +33,12 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::Gelu, // TODO: hardcoded
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_shape,
f_traits>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
......
......@@ -8,10 +8,7 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......
......@@ -62,6 +62,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......
......@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
};
template <bf16_rounding_mode rounding =
......@@ -180,6 +181,33 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32);
}
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
static constexpr uint32_t FP32_NAN = 0x7fff0000;
static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "=s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(ROUND_BIAS_FOR_BF16), [v_bhi] "v"(FP32_NAN));
return uint16_t(u.int32);
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
......@@ -213,6 +241,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}
......
......@@ -624,6 +624,40 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
// return [m0_init_value, size_per_issue]
// m0_init_value-> directly use this to set m0 value
// size_per_issue-> direclty use this to inc m0 every issue
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_smem_info(LdsTileWindow_&& lds_tile) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
return make_tuple(m0_init_value, size_per_issue);
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access = -1,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace ck_tile {
#if 1
// input a lds store tile, extract some information from it
// used to set m0 value for gfx9 serious
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
return make_tuple(m0_init_value, size_per_issue);
}
#else
#define GET_ASYNC_STORE_SMEM_INFO(lds_tile__) \
[&](auto lds_tile_) { \
using LdsTileWindow = remove_cvref_t<decltype(lds_tile_)>; \
using LdsDataType = typename LdsTileWindow::DataType; \
\
/* issues * warps * lanes */ \
static_assert(LdsTileWindow::get_num_of_dimension() == 3); \
\
const index_t size_per_buf_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<0>{}, number<0>{}, number<0>{})) * \
sizeof(LdsDataType); \
\
const index_t size_per_wave_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<0>{}, number<1>{}, number<0>{})) * \
sizeof(LdsDataType) - \
size_per_buf_; \
\
const index_t size_per_issue_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<1>{}, number<0>{}, number<0>{})) * \
sizeof(LdsDataType) - \
size_per_buf_; \
\
const index_t m0_init_value_ = size_per_buf_ + size_per_wave_ * get_warp_id(); \
\
return make_tuple(m0_init_value_, size_per_issue_); \
}(lds_tile__)
#endif
} // namespace ck_tile
......@@ -572,6 +572,49 @@ struct FastGelu
}
};
struct FastGeluAsm
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = 0xbd92220c; // -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp;
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
:);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.hpp"
#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -8,6 +8,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
......@@ -134,6 +134,7 @@ struct FusedMoeGemmKernel
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool UseUK = true;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
......@@ -210,6 +211,23 @@ struct FusedMoeGemmKernel
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
}
else
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
......@@ -220,7 +238,8 @@ struct FusedMoeGemmKernel
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t kr_1 =
kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
......@@ -233,8 +252,9 @@ struct FusedMoeGemmKernel
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
const IndexDataType expert_id =
__builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
......@@ -243,12 +263,13 @@ struct FusedMoeGemmKernel
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
const auto sorted_token_id =
a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
......@@ -361,7 +382,9 @@ struct FusedMoeGemmKernel
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size);
kargs.intermediate_size,
kargs.stride_token);
}
}
};
......
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
......@@ -318,6 +319,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
// number<S_::WarpPerBlock_N0>{}.rrr();
// number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0,
......@@ -556,7 +559,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr index_t KPad = 0; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......@@ -573,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......@@ -589,7 +592,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape;
// A is vgpr, B is agpr. But since we transposed, so also need swap this
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
......@@ -716,7 +719,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
......@@ -812,5 +815,31 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16{};
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "fused_moe_flatmm_uk";
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
return MRepeat;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
auto base_coord = threadIdx.x / KLans + base_offset;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords,
const IndexDataType* sorted_token_ids_ptr)
{
constexpr index_t n_size = coords.size();
array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
});
return row_ids;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
{
constexpr index_t WarpGemmLane_M = 16; // TODO: use 16x16
constexpr index_t WarpGemmRepeat_M = BlockShape::Block_M0 / WarpGemmLane_M;
auto base_coord = threadIdx.x % WarpGemmLane_M + base_offset;
array<index_t, WarpGemmRepeat_M> coords;
static_for<0, WarpGemmRepeat_M, 1>{}(
[&](auto i) { coords.at(i) = base_coord + i * WarpGemmLane_M; });
return coords;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
const TopkWeightDataType* sorted_weight_ptr)
{
constexpr index_t n_size = coords.size();
array<index_t, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
});
return w;
}
CK_TILE_DEVICE auto GetRowCoords_O()
{
constexpr index_t NLans = BlockShape::Block_N1 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / NLans;
constexpr index_t MRepeat = BlockShape::Block_M1 / MLans;
auto base_coord = threadIdx.x / NLans;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
/*
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
*/
template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs,
CK_TILE_LDS_ADDR void* smem,
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID_A(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto a_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
number<row_ids_a.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
const auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
// number<BlockShape::Block_Nr0>{}.fff();
// number<kAlignmentG>{}.zzz();
const auto g_window_ =
make_tile_window_linear(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
// number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<kAlignmentD>{},
number<1>{});
const auto d_window_ =
make_tile_window_linear(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
return d_window_;
}();
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
#if 0
auto d_coords = generate_tuple([&](auto i) {
return d_win.cached_coords_[i].get_offset(); },
number<decltype(d_win)::NumAccess_NonLinear>{});
#else
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto d_coords = [&]() {
constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 4;
constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8;
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
constexpr index_t num_offsets_ = Nr_ * Kr0_;
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_;
return generate_tuple(
[&](auto i) {
constexpr auto i_nr_ = number<i % Nr_>{};
constexpr auto i_kr0_ = number<i / Nr_>{};
return i_nr_ * kargs.intermediate_size * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
base_os_;
},
number<num_offsets_>{});
}();
#endif
auto o_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
number<a_coords.size()>{});
auto bridge_sst_win = [&]() {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0});
}();
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
kargs.stride_token,
BlockShape::Block_Kr0 * BlockShape::Block_W0);
sweep_tile(acc_0,
[&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
d_coords,
o_res,
o_coords,
smem,
kargs.hidden_size,
w_scale,
BlockShape::Block_Kr0 * BlockShape::Block_W0,
kargs.stride_token);
}
};
} // namespace ck_tile
......@@ -18,6 +18,7 @@ enum class WGAttrCtlEnum
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
......@@ -38,6 +39,28 @@ enum class WGAttrCtlEnum
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
......@@ -72,22 +95,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
else
{
#if defined(__gfx9__)
......@@ -147,22 +155,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
else
{
#if defined(__gfx9__)
......@@ -223,22 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
......@@ -324,23 +302,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "a", "v")
}
else
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
......@@ -623,22 +585,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
else
{
#if defined(__gfx94__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment