Commit 06914eed authored by carlushuang

block-asm

parent b0dd570a
@@ -12,6 +12,7 @@ set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
 # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
 list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
 list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
+list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
 # list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
 list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
     if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
        t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
     {
-        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
         fused_moegemm_<t_>(s, a);
     }
     // clang-format on
...
@@ -33,11 +33,12 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
         typename Ts_::YSmoothScaleDataType,
         typename Ts_::TopkWeightDataType,
         typename Ts_::IndexDataType,
-        ck_tile::element_wise::Gelu, // TODO: hardcoded
+        ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
         f_shape,
         f_traits>;
-    using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
+    // using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
+    using f_pipeline    = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
     using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
     using f_kernel      = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
...
@@ -8,10 +8,7 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
-template float fused_moegemm_<
-    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
->(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
@@ -52,6 +52,7 @@
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
 #include "ck_tile/core/tensor/tile_window_linear.hpp"
+#include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
...
@@ -62,6 +62,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
+#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...
@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
     truncate_with_nan,
     truncate,
     standard_asm,
+    rta_asm, // round to nearest away
 };

 template <bf16_rounding_mode rounding =
@@ -180,6 +181,33 @@ uint16_t float_to_bf16_rtn_asm(float f)
     return uint16_t(u.int32);
 }

+// TODO: do we need this on host?
+CK_TILE_HOST
+uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
+
+CK_TILE_DEVICE
+uint16_t float_to_bf16_rta_asm(float f)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {f};
+    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
+    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
+    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
+    uint32x2_t check_nan;
+
+    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
+                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
+                 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
+                 : [s_cnan] "=s"(check_nan), [v_x] "+v"(u.fp32)
+                 : [v_blo] "v"(ROUND_BIAS_FOR_BF16), [v_bhi] "v"(FP32_NAN));
+
+    return uint16_t(u.int32);
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE
 constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -213,6 +241,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
         return float_to_bf16_rtn_asm(f);
     else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
         return float_to_bf16_truc_nan_raw(f);
+    else if constexpr(rounding == bf16_rounding_mode::rta_asm)
+        return float_to_bf16_rta_asm(f);
     else
         return float_to_bf16_truc_raw(f);
 }
...
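Note: as a reference for the `rta_asm` path, here is a scalar sketch of what the asm sequence computes as I read it (the helper name and the explicit `>> 16` extraction are my additions; the device routine above returns the raw low half and presumably relies on its caller's packing). `v_cmp_u_f32 x, x` is an unordered self-compare, i.e. a NaN test, and adding `0x8000` (= `ROUND_BIAS_FOR_BF16 + 1`) rounds the value carried in the upper 16 bits to nearest, ties away from zero.

#include <cstdint>
#include <cstring>

// Hypothetical scalar reference for round-to-nearest-away float->bf16.
inline uint16_t float_to_bf16_rta_ref(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if(f != f)
        bits = 0x7fff0000u; // quiet-NaN pattern; upper half becomes 0x7fff
    else
        bits += 0x8000u;    // 0x7fff + 1, as in the v_add3_u32 above
    return static_cast<uint16_t>(bits >> 16); // keep the bf16 (upper) half
}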
@@ -624,6 +624,40 @@ struct tile_window_linear
         WINDOW_DISPATCH_ISSUE();
     }

+    // return [m0_init_value, size_per_issue]
+    // m0_init_value  -> directly use this to set the m0 value
+    // size_per_issue -> directly use this to increment m0 on every issue
+    template <typename LdsTileWindow_>
+    CK_TILE_DEVICE auto get_smem_info(LdsTileWindow_&& lds_tile) const
+    {
+        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
+        using LdsDataType   = typename LdsTileWindow::DataType;
+
+        // issues * warps * lanes
+        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
+
+        const index_t size_per_buf =
+            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
+            sizeof(LdsDataType);
+
+        const index_t size_per_wave =
+            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
+                sizeof(LdsDataType) -
+            size_per_buf;
+
+        const index_t size_per_issue =
+            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
+                sizeof(LdsDataType) -
+            size_per_buf;
+
+        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
+
+        return make_tuple(m0_init_value, size_per_issue);
+    }
+
     // TODO: currently async load only implemented in inline asm
     template <typename LdsTileWindow_,
               index_t i_access = -1,
...
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/core/arch/utility.hpp"
+#include "ck_tile/core/algorithm/space_filling_curve.hpp"
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/core/container/array.hpp"
+#include "ck_tile/core/container/sequence.hpp"
+#include "ck_tile/core/container/tuple.hpp"
+#include "ck_tile/core/container/container_helper.hpp"
+#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
+#include "ck_tile/core/tensor/tensor_adaptor.hpp"
+#include "ck_tile/core/tensor/tile_distribution.hpp"
+#include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/type_traits.hpp"
+
+#pragma once
+
+namespace ck_tile {
+
+#if 1
+// input an lds store tile, extract some information from it
+// used to set the m0 value for the gfx9 series
+template <typename LdsTileWindow_>
+CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
+{
+    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
+    using LdsDataType   = typename LdsTileWindow::DataType;
+
+    // issues * warps * lanes
+    static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
+
+    const index_t size_per_buf =
+        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+            make_tuple(number<0>{}, number<0>{}, number<0>{})) *
+        sizeof(LdsDataType);
+
+    const index_t size_per_wave =
+        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+            make_tuple(number<0>{}, number<1>{}, number<0>{})) *
+            sizeof(LdsDataType) -
+        size_per_buf;
+
+    const index_t size_per_issue =
+        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
+            make_tuple(number<1>{}, number<0>{}, number<0>{})) *
+            sizeof(LdsDataType) -
+        size_per_buf;
+
+    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
+
+    return make_tuple(m0_init_value, size_per_issue);
+}
+#else
+#define GET_ASYNC_STORE_SMEM_INFO(lds_tile__)                                            \
+    [&](auto lds_tile_) {                                                                \
+        using LdsTileWindow = remove_cvref_t<decltype(lds_tile_)>;                       \
+        using LdsDataType   = typename LdsTileWindow::DataType;                          \
+                                                                                         \
+        /* issues * warps * lanes */                                                     \
+        static_assert(LdsTileWindow::get_num_of_dimension() == 3);                       \
+                                                                                         \
+        const index_t size_per_buf_ =                                                    \
+            lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
+                make_tuple(number<0>{}, number<0>{}, number<0>{})) *                     \
+            sizeof(LdsDataType);                                                         \
+                                                                                         \
+        const index_t size_per_wave_ =                                                   \
+            lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
+                make_tuple(number<0>{}, number<1>{}, number<0>{})) *                     \
+                sizeof(LdsDataType) -                                                    \
+            size_per_buf_;                                                               \
+                                                                                         \
+        const index_t size_per_issue_ =                                                  \
+            lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
+                make_tuple(number<1>{}, number<0>{}, number<0>{})) *                     \
+                sizeof(LdsDataType) -                                                    \
+            size_per_buf_;                                                               \
+                                                                                         \
+        const index_t m0_init_value_ = size_per_buf_ + size_per_wave_ * get_warp_id();   \
+                                                                                         \
+        return make_tuple(m0_init_value_, size_per_issue_);                              \
+    }(lds_tile__)
+#endif
+
+} // namespace ck_tile
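Note: a hypothetical caller sketch for the tuple returned above (the inline asm and the `async_load` call are assumptions for illustration, not APIs confirmed by this diff). On gfx9 the M0 register supplies the LDS destination for `buffer_load ... lds`, so the first element seeds M0 once per wave and the second element is added after every issue:

template <typename LdsWindow, typename GldWindow>
CK_TILE_DEVICE void async_copy_sketch(LdsWindow& lds_win, GldWindow& gld_win)
{
    const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(lds_win);

    // seed M0 with this wave's LDS base (size_per_buf + per-wave offset)
    asm volatile("s_mov_b32 m0, %0" : : "s"(m0_init_value));

    constexpr index_t num_issues = 4; // assumed issue count for the sketch
    static_for<0, num_issues, 1>{}([&](auto i) {
        gld_win.async_load(lds_win, number<i>{}); // assumed per-issue async load
        asm volatile("s_add_u32 m0, m0, %0" : : "s"(size_per_issue)); // bump LDS dst
    });
}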
@@ -572,6 +572,49 @@ struct FastGelu
     }
 };

+struct FastGeluAsm
+{
+    template <typename Y, typename X>
+    CK_TILE_HOST void operator()(Y& y, const X& x) const;
+
+    template <typename Y, typename X>
+    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
+
+    template <>
+    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
+    {
+        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
+        const float c1  = -2.0 * 0.035677f;
+        const float c2  = -2.0 * 0.797885f;
+        const float u   = x * (c1 * x * x + c2);
+        const float emu = exp(u);
+        y               = x / (1.f + emu);
+    }
+
+    // device code, use lower precision "__ocml_exp_f32" and "rcp"
+    template <>
+    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
+    {
+        // const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float c1     = 0xbd92220c; // -2.0 * 0.035677f;
+        const float c2     = -2.0 * 0.797885f;
+        const float log2e_ = 0x3fb8aa3b; // log2e_v<float>;
+        float tmp;
+        asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x]             ; x*x\n"
+                     "v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
+                     "v_mul_f32 %[v_tmp], %[v_tmp], %[v_x]           ; x*(c1*x*x+c2)\n"
+                     "v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e]       ; log2e*x*(c1*x*x+c2)\n"
+                     "v_exp_f32 %[v_tmp], %[v_tmp]                   ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
+                     "v_add_f32 %[v_tmp], %[v_tmp], 1.0              ; emu+1.0f\n"
+                     "v_rcp_f32 %[v_tmp], %[v_tmp]                   ; 1/(emu+1.0f)\n"
+                     "v_mul_f32 %[v_y], %[v_tmp], %[v_x]             ; x * 1/(emu+1f)\n"
+                     : [v_y] "=v"(y), [v_tmp] "+v"(tmp)
+                     : [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
+                     :);
+    }
+};
+
 // https://paperswithcode.com/method/gelu
 // y = 0.5*x*(1+erf(x/sqrt(2)))
 struct Gelu
...
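Note: the constants in FastGeluAsm follow from the tanh approximation of GELU rewritten as a logistic function (my derivation, not part of the diff):

\[
\mathrm{gelu}(x)\approx \tfrac12\,x\Bigl(1+\tanh\bigl(\sqrt{2/\pi}\,(x+0.044715\,x^{3})\bigr)\Bigr),
\qquad
\tfrac12\,(1+\tanh z)=\frac{1}{1+e^{-2z}}
\;\Longrightarrow\;
y=\frac{x}{1+e^{u}},\quad u=x\,(c_{1}x^{2}+c_{2}),
\]

with \(c_1=-2\sqrt{2/\pi}\cdot 0.044715\approx -2\cdot 0.035677\) and \(c_2=-2\sqrt{2/\pi}\approx -2\cdot 0.797885\). The device path then uses \(e^{u}=2^{\,u\log_2 e}\), which is why the asm multiplies by `log2e` before `v_exp_f32` (a base-2 exponential on gfx9).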
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.hpp"
+#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
@@ -8,6 +8,7 @@
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
+#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
...
@@ -133,7 +133,8 @@ struct FusedMoeGemmKernel
     using IndexDataType = typename Pipeline::Problem::IndexDataType;
     using YDataType     = typename Pipeline::Problem::YDataType;
     using Traits        = typename Pipeline::Problem::Traits;
+    static constexpr bool UseUK = true;

     static constexpr bool IsGateOnly     = Traits::IsGateOnly;
     static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
@@ -211,157 +212,179 @@ struct FusedMoeGemmKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
-        // allocate LDS
-        // __shared__ char smem_ptr[GetSmemSize()];
-        IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
-            *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
-        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
-
-        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
-        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
-        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
-        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
-
-        index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
-        index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
-
-        __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
-
-        // note this is in unit of tile, need multiple tile size to get the index
-        const auto [sorted_tile_id, intermediate_tile_id] =
-            Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
-        if(sorted_tile_id >= num_sorted_tiles)
-            return;
-
-        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
-            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
-
-        // index along intermediate_size
-        // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
-        // BlockShape::Block_N0);
-        index_t interm_idx_nr =
-            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
-
-        const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
-        const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
-
-        index_t token_id =
-            reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
-        auto topk_weight =
-            reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
-
-        const auto a_window = [&]() {
-            // A is already pre-padded in previous kernel
-            const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
-            const auto a_view_     = make_naive_tensor_view<address_space_enum::global>(
-                a_ptr,
-                make_tuple(kargs.num_tokens, kargs.hidden_size),
-                make_tuple(kargs.stride_token, 1),
-                number<Pipeline::kAlignmentA>{},
-                number<1>{});
-
-            // gather is here use indexing transform
-            const auto a_gather_view_ = transform_tensor_view(
-                a_view_,
-                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
-                           make_pass_through_transform(kargs.hidden_size)),
-                make_tuple(sequence<0>{}, sequence<1>{}),
-                make_tuple(sequence<0>{}, sequence<1>{}));
-
-            const auto a_window_ = make_tile_window(
-                a_gather_view_,
-                make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
-                {0, 0});
-            return a_window_;
-        }();
-
-        // TODO: gtile using NSub to have less register pressure
-        const auto g_window = [&]() {
-            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
-                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
-                                     interm_idx_nr * kr_0 * BlockShape::Block_W0;
-            const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
-                g_ptr,
-                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
-                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
-                number<Pipeline::kAlignmentG>{},
-                number<1>{});
-            const auto g_view_1_ =
-                pad_tensor_view(g_view_,
-                                make_tuple(number<BlockShape::Block_Nr0>{},
-                                           number<BlockShape::Block_Kr0>{},
-                                           number<BlockShape::Block_W0>{}),
-                                sequence<PadIntermediateSize, PadHiddenSize, 0>{});
-
-            const auto g_window_ = make_tile_window(g_view_1_,
-                                                    make_tuple(number<BlockShape::Block_Nr0>{},
-                                                               number<BlockShape::Block_Kr0>{},
-                                                               number<BlockShape::Block_W0>{}),
-                                                    {0, 0, 0});
-            return g_window_;
-        }();
-
-        const auto d_window = [&]() {
-            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
-                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
-                                     interm_idx_nr * BlockShape::Block_W1;
-            // note interm_idx_nr is along the gemm-k dim of 2nd gemm
-
-            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
-                d_ptr,
-                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
-                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
-                number<Pipeline::kAlignmentD>{},
-                number<1>{});
-            const auto d_view_1_ =
-                pad_tensor_view(d_view_,
-                                make_tuple(number<BlockShape::Block_Nr1>{},
-                                           number<BlockShape::Block_Kr1>{},
-                                           number<BlockShape::Block_W1>{}),
-                                sequence<PadHiddenSize, PadIntermediateSize, 0>{});
-
-            const auto d_window_ = make_tile_window(d_view_1_,
-                                                    make_tuple(number<BlockShape::Block_Nr1>{},
-                                                               number<BlockShape::Block_Kr1>{},
-                                                               number<BlockShape::Block_W1>{}),
-                                                    {0, 0, 0});
-            return d_window_;
-        }();
-
-        auto o_window = [&]() {
-            ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
-            auto o_view_     = make_naive_tensor_view<address_space_enum::global,
-                                                  memory_operation_enum::atomic_add>(
-                o_ptr,
-                make_tuple(kargs.num_tokens, kargs.hidden_size),
-                make_tuple(kargs.stride_token, 1),
-                number<Pipeline::kAlignmentO>{},
-                number<1>{});
-
-            // gather is here
-            auto o_scatter_view_ = transform_tensor_view(
-                o_view_,
-                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
-                           make_pass_through_transform(kargs.hidden_size)),
-                make_tuple(sequence<0>{}, sequence<1>{}),
-                make_tuple(sequence<0>{}, sequence<1>{}));
-
-            auto o_window_ = make_tile_window(
-                o_scatter_view_,
-                make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
-                {0, 0});
-            return o_window_;
-        }();
-
-        // do compute yeah
-        Pipeline{}(a_window,
-                   g_window,
-                   d_window,
-                   o_window,
-                   topk_weight,
-                   smem,
-                   kargs.hidden_size,
-                   kargs.intermediate_size);
+        if constexpr(UseUK)
+        {
+            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
+            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
+                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
+            num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
+
+            const auto [sorted_tile_id, intermediate_tile_id] =
+                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
+            if(sorted_tile_id >= num_sorted_tiles)
+                return;
+
+            Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
+        }
+        else
+        {
+            // allocate LDS
+            // __shared__ char smem_ptr[GetSmemSize()];
+            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
+                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
+            constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
+
+            index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
+            index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
+            index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
+            index_t kr_1 =
+                kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
+
+            index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
+            index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
+
+            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
+
+            // note this is in unit of tile, need multiple tile size to get the index
+            const auto [sorted_tile_id, intermediate_tile_id] =
+                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
+            if(sorted_tile_id >= num_sorted_tiles)
+                return;
+
+            const IndexDataType expert_id =
+                __builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
+                    kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
+
+            // index along intermediate_size
+            // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
+            // BlockShape::Block_N0);
+            index_t interm_idx_nr =
+                __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
+
+            const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
+            const auto sorted_token_id =
+                a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
+
+            index_t token_id =
+                reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
+            auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
+                kargs.sorted_weight_ptr)[sorted_token_id];
+
+            const auto a_window = [&]() {
+                // A is already pre-padded in previous kernel
+                const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
+                const auto a_view_     = make_naive_tensor_view<address_space_enum::global>(
+                    a_ptr,
+                    make_tuple(kargs.num_tokens, kargs.hidden_size),
+                    make_tuple(kargs.stride_token, 1),
+                    number<Pipeline::kAlignmentA>{},
+                    number<1>{});
+
+                // gather is here use indexing transform
+                const auto a_gather_view_ = transform_tensor_view(
+                    a_view_,
+                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
+                               make_pass_through_transform(kargs.hidden_size)),
+                    make_tuple(sequence<0>{}, sequence<1>{}),
+                    make_tuple(sequence<0>{}, sequence<1>{}));
+
+                const auto a_window_ = make_tile_window(
+                    a_gather_view_,
+                    make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
+                    {0, 0});
+                return a_window_;
+            }();
+
+            // TODO: gtile using NSub to have less register pressure
+            const auto g_window = [&]() {
+                const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
+                                         static_cast<long_index_t>(expert_id) * expert_stride_0 +
+                                         interm_idx_nr * kr_0 * BlockShape::Block_W0;
+                const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
+                    g_ptr,
+                    make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
+                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
+                    number<Pipeline::kAlignmentG>{},
+                    number<1>{});
+                const auto g_view_1_ =
+                    pad_tensor_view(g_view_,
+                                    make_tuple(number<BlockShape::Block_Nr0>{},
+                                               number<BlockShape::Block_Kr0>{},
+                                               number<BlockShape::Block_W0>{}),
+                                    sequence<PadIntermediateSize, PadHiddenSize, 0>{});
+
+                const auto g_window_ = make_tile_window(g_view_1_,
+                                                        make_tuple(number<BlockShape::Block_Nr0>{},
+                                                                   number<BlockShape::Block_Kr0>{},
+                                                                   number<BlockShape::Block_W0>{}),
+                                                        {0, 0, 0});
+                return g_window_;
+            }();
+
+            const auto d_window = [&]() {
+                const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
+                                         static_cast<long_index_t>(expert_id) * expert_stride_1 +
+                                         interm_idx_nr * BlockShape::Block_W1;
+                // note interm_idx_nr is along the gemm-k dim of 2nd gemm
+
+                const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
+                    d_ptr,
+                    make_tuple(nr_1, kr_1, BlockShape::Block_W1),
+                    make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
+                    number<Pipeline::kAlignmentD>{},
+                    number<1>{});
+                const auto d_view_1_ =
+                    pad_tensor_view(d_view_,
+                                    make_tuple(number<BlockShape::Block_Nr1>{},
+                                               number<BlockShape::Block_Kr1>{},
+                                               number<BlockShape::Block_W1>{}),
+                                    sequence<PadHiddenSize, PadIntermediateSize, 0>{});
+
+                const auto d_window_ = make_tile_window(d_view_1_,
+                                                        make_tuple(number<BlockShape::Block_Nr1>{},
+                                                                   number<BlockShape::Block_Kr1>{},
+                                                                   number<BlockShape::Block_W1>{}),
+                                                        {0, 0, 0});
+                return d_window_;
+            }();
+
+            auto o_window = [&]() {
+                ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
+                auto o_view_     = make_naive_tensor_view<address_space_enum::global,
+                                                      memory_operation_enum::atomic_add>(
+                    o_ptr,
+                    make_tuple(kargs.num_tokens, kargs.hidden_size),
+                    make_tuple(kargs.stride_token, 1),
+                    number<Pipeline::kAlignmentO>{},
+                    number<1>{});
+
+                // gather is here
+                auto o_scatter_view_ = transform_tensor_view(
+                    o_view_,
+                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
+                               make_pass_through_transform(kargs.hidden_size)),
+                    make_tuple(sequence<0>{}, sequence<1>{}),
+                    make_tuple(sequence<0>{}, sequence<1>{}));
+
+                auto o_window_ = make_tile_window(
+                    o_scatter_view_,
+                    make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
+                    {0, 0});
+                return o_window_;
+            }();
+
+            // do compute yeah
+            Pipeline{}(a_window,
+                       g_window,
+                       d_window,
+                       o_window,
+                       topk_weight,
+                       smem,
+                       kargs.hidden_size,
+                       kargs.intermediate_size,
+                       kargs.stride_token);
+        }
     }
 };
...
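Note: one detail worth calling out in the output path above (my reading of the scatter, not a change-note): with topk > 1 the same token row is visited by several sorted tiles, one per selected expert, so the O view is created with `memory_operation_enum::atomic_add` and the scatter-store accumulates rather than overwrites. Roughly:

// per selected expert e of a token row t (sketch, not actual code):
//   o[t][:] += topk_weight[t, e] * activation(a[t] @ g[e]) @ d[e]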
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
+#include "ck_tile/ops/flatmm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
@@ -318,6 +319,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         using S_ = typename Problem::BlockShape;
         if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
         {
+            // number<S_::WarpPerBlock_N0>{}.rrr();
+            // number<S_::Repeat_N0>{}.eee();
             return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
                                                       S_::WarpPerBlock_K0,
                                                       S_::Repeat_N0, /// hidden_radio_0,
@@ -556,7 +559,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         constexpr index_t Block_N = Problem::BlockShape::Block_N0;
         constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
-        constexpr index_t KPad    = KVector;                   // pad between warps
+        constexpr index_t KPad    = 0;                         // pad between warps

         constexpr auto desc =
             make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
@@ -573,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         constexpr index_t Block_N = Problem::BlockShape::Block_N0;
         constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
-        constexpr index_t KPad    = KVector;                   // pad between warps
+        constexpr index_t KPad    = 0; // KVector;             // pad between warps

         constexpr auto desc =
             make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
@@ -589,7 +592,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         using S_ = typename Problem::BlockShape;
         // A is vgpr, B is agpr. But since we transposed, so also need swap this
         // TODO: this is ugly
-        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
+        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
         // TODO: ugly
         if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                      std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
@@ -716,7 +719,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
     CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
     {
         using S_ = typename Problem::BlockShape;
-        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
+        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
         // TODO: ugly
         if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
@@ -812,5 +815,31 @@ struct FusedMoeGemmPipelineFlatmmPolicy
             make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
         return y_block_tensor;
     }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
+    {
+        using S_ = typename Problem::BlockShape;
+        if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
+                     std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
+                     S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
+                     S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
+        {
+            return FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16{};
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
+    {
+        using S_ = typename Problem::BlockShape;
+        if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
+                     std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
+                     S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
+                     S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
+        {
+            return FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16{};
+        }
+    }
 };

 } // namespace ck_tile
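Note: GetUK_0/GetUK_1 use the common `auto`-return `if constexpr` selection idiom: each supported (dtype, tile-shape) combination returns a distinct micro-kernel tag type, and unsupported combinations fall through to a `void` return that only fails to compile if the result is actually used. A minimal standalone sketch of the idiom (names are mine, not from the diff):

#include <type_traits>

struct UkBf16_32x512x128 {}; // stand-in for FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16

template <int BlockN>
constexpr auto select_uk()
{
    if constexpr(BlockN == 512)
        return UkBf16_32x512x128{};
    // no matching branch: return type deduces to void, so using the
    // result for an unsupported shape is a compile-time error
}

static_assert(!std::is_void_v<decltype(select_uk<512>())>);
static_assert(std::is_void_v<decltype(select_uk<256>())>);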
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
+
+namespace ck_tile {
+
+/*
+  This pipeline deals with a gemm (actually 2 gemms) with one very small (token) and one
+  very big (weight) operand; we need to design the pipeline such that all waves are along
+  the gemm-N dim (gemm-M has only 1 wave)
+
+   <----- gemm-N ------>
+  +----+----+----+----+
+  | w0 | w1 | w2 | w3 |  gemm-m
+  +----+----+----+----+
+*/
+template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
+struct FusedMoeGemmPipeline_FlatmmUk
+{
+    using Problem = remove_cvref_t<Problem_>;
+    using Policy  = remove_cvref_t<Policy_>;
+
+    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
+
+    using ADataType            = typename Problem::ADataType;
+    using GDataType            = typename Problem::GDataType;
+    using DDataType            = typename Problem::DDataType;
+    using AccDataType          = typename Problem::AccDataType;
+    using ODataType            = typename Problem::ODataType;
+    using AScaleDataType       = typename Problem::AScaleDataType;
+    using GScaleDataType       = typename Problem::GScaleDataType;
+    using DScaleDataType       = typename Problem::DScaleDataType;
+    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
+    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
+    using IndexDataType        = typename Problem::IndexDataType;
+    using YDataType            = typename Problem::YDataType;
+
+    using Traits = typename Problem::Traits;
+
+    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
+    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
+    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
+    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
+
+    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
+    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
+    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
+    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
+
+    static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
+    static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
+    static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
+    static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
+
+    static constexpr index_t kBlockPerCu = []() {
+        if constexpr(Problem::kBlockPerCu != -1)
+            return Problem::kBlockPerCu;
+        else
+        {
+            // minimize occupancy
+            return 2;
+        }
+    }();
+
+    static constexpr const char* name = "fused_moe_flatmm_uk";
+
+    // TODO: there are multiple buffers
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
+    {
+        return Policy::template GetSmemSize_A<Problem>();
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return Policy::template GetSmemSize<Problem>();
+    }
+
+    // this is the thread-offset along row/col
+    CK_TILE_HOST_DEVICE static auto GetACoord()
+    {
+        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
+        const auto a_coord    = a_dist.calculate_index();
+        return a_coord;
+    }
+
+    // this is the thread-offset along row/col
+    CK_TILE_HOST_DEVICE static auto GetOCoord()
+    {
+        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
+        const auto o_coord    = o_dist.calculate_index();
+        return o_coord;
+    }
+
+    CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
+    {
+        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
+        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
+        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
+        return MRepeat;
+    }
+
+    // TODO: properly support scatter/gather
+    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
+    {
+        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
+        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
+        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
+
+        auto base_coord = threadIdx.x / KLans + base_offset;
+
+        array<index_t, MRepeat> coords;
+        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
+
+        return coords;
+    }
+
+    template <typename ROW_COORDS>
+    CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords,
+                                   const IndexDataType* sorted_token_ids_ptr)
+    {
+        constexpr index_t n_size = coords.size();
+
+        array<index_t, n_size> row_ids;
+        static_for<0, n_size, 1>{}([&](auto i) {
+            row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
+        });
+
+        return row_ids;
+    }
+
+    // TODO: properly support scatter/gather
+    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
+    {
+        constexpr index_t WarpGemmLane_M   = 16; // TODO: use 16x16
+        constexpr index_t WarpGemmRepeat_M = BlockShape::Block_M0 / WarpGemmLane_M;
+
+        auto base_coord = threadIdx.x % WarpGemmLane_M + base_offset;
+
+        array<index_t, WarpGemmRepeat_M> coords;
+        static_for<0, WarpGemmRepeat_M, 1>{}(
+            [&](auto i) { coords.at(i) = base_coord + i * WarpGemmLane_M; });
+
+        return coords;
+    }
+
+    template <typename ROW_COORDS>
+    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
+                                       const TopkWeightDataType* sorted_weight_ptr)
+    {
+        constexpr index_t n_size = coords.size();
+
+        array<index_t, n_size> w;
+        static_for<0, n_size, 1>{}([&](auto i) {
+            w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
+        });
+
+        return w;
+    }
+
+    CK_TILE_DEVICE auto GetRowCoords_O()
+    {
+        constexpr index_t NLans   = BlockShape::Block_N1 / kAlignmentA;
+        constexpr index_t MLans   = BlockShape::BlockSize / NLans;
+        constexpr index_t MRepeat = BlockShape::Block_M1 / MLans;
+
+        auto base_coord = threadIdx.x / NLans;
+
+        array<index_t, MRepeat> coords;
+        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
+
+        return coords;
+    }
+
+    /*
+    struct FusedMoeGemmKargs
+    {
+        const void* a_ptr;              // [m, k], input token
+        const void* a_scale_ptr;        // [m, 1], token scale
+        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
+        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
+        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
+        const void* d_scale_ptr;        // [e, 1, k], down scale
+        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
+        void* o_ptr;                    // [m, k], output token
+
+        const void* sorted_token_ids_ptr;
+        const void* sorted_weight_ptr;
+        const void* sorted_expert_ids_ptr;
+        const void* num_sorted_tiles_ptr;
+
+        index_t hidden_size;       // k
+        index_t intermediate_size; // n (TP slices this)
+        index_t num_tokens;        // input number of tokens for current iteration
+        index_t num_experts;       // number of groups
+        index_t topk;              // need this?
+
+        index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
+    };
+    */
+    template <typename Karg>
+    CK_TILE_DEVICE auto operator()(const Karg& kargs,
+                                   CK_TILE_LDS_ADDR void* smem,
+                                   index_t sorted_tile_id,
+                                   index_t intermediate_tile_id)
+    {
+        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
+        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
+        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
+        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
+
+        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
+            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
+        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
+
+        index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
+        index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
+
+        index_t interm_idx_nr =
+            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
+
+        auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
+        auto row_ids_a    = GetRowID_A(
+            row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
+        auto a_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
+                                       number<row_ids_a.size()>{});
+        auto a_res =
+            make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
+                                      kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
+
+        const auto g_win = [&]() {
+            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
+                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
+                                     interm_idx_nr * kr_0 * BlockShape::Block_W0;
+            const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
+                g_ptr,
+                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
+                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
+                number<kAlignmentG>{},
+                number<1>{});
+
+            // number<BlockShape::Block_Nr0>{}.fff();
+            // number<kAlignmentG>{}.zzz();
+
+            const auto g_window_ =
+                make_tile_window_linear(g_view_,
+                                        make_tuple(number<BlockShape::Block_Nr0>{},
+                                                   number<BlockShape::Block_Kr0>{},
+                                                   number<BlockShape::Block_W0>{}),
+                                        {0, 0, 0},
+                                        Policy::template MakeGlobalTileDistribution_G<Problem>(),
+                                        sequence<0, 1, 1>{});
+            return g_window_;
+        }();
+        // number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
+        auto g_res    = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
+        auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
+                                       number<decltype(g_win)::NumAccess_NonLinear>{});
+
+        const auto d_win = [&]() {
+            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
+                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
+                                     interm_idx_nr * BlockShape::Block_W1;
+            // note interm_idx_nr is along the gemm-k dim of 2nd gemm
+            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
+                d_ptr,
+                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
+                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
+                number<kAlignmentD>{},
+                number<1>{});
+
+            const auto d_window_ =
+                make_tile_window_linear(d_view_,
+                                        make_tuple(number<BlockShape::Block_Nr1>{},
+                                                   number<BlockShape::Block_Kr1>{},
+                                                   number<BlockShape::Block_W1>{}),
+                                        {0, 0, 0},
+                                        Policy::template MakeGlobalTileDistribution_D<Problem>(),
+                                        sequence<0, 1, 1>{});
+            return d_window_;
+        }();
+
+        auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
+#if 0
+        auto d_coords = generate_tuple([&](auto i) {
+            return d_win.cached_coords_[i].get_offset(); },
+            number<decltype(d_win)::NumAccess_NonLinear>{});
+#else
+        // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
+        //       block-k=512, block-n=128
+        //                                    |<----- W_ ----->|
+        //  Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)] -> one issue
+        //   y    p     y    y      p         p     y     p
+        //   1               2               0(imm)
+        auto d_coords = [&]() {
+            constexpr index_t Nr_  = 2;
+            constexpr index_t Nw_  = 4;
+            constexpr index_t Kr0_ = 4;
+            constexpr index_t Kr1_ = 4;
+            constexpr index_t Kl_  = 4;
+            constexpr index_t Nl_  = 16;
+            constexpr index_t Kv_  = 8;
+            constexpr index_t W_   = Kl_ * Nl_ * Kv_;
+
+            constexpr index_t num_offsets_ = Nr_ * Kr0_;
+            index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_;
+            return generate_tuple(
+                [&](auto i) {
+                    constexpr auto i_nr_  = number<i % Nr_>{};
+                    constexpr auto i_kr0_ = number<i / Nr_>{};
+                    return i_nr_ * kargs.intermediate_size * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
+                           base_os_;
+                },
+                number<num_offsets_>{});
+        }();
+#endif
+
+        auto o_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
+                                       number<a_coords.size()>{});
+
+        auto bridge_sst_win = [&]() {
+            return make_tile_window(
+                make_tensor_view<address_space_enum::lds>(
+                    reinterpret_cast<YDataType*>(smem),
+                    Policy::template MakeBridgeLdsStoreDesc<Problem>()),
+                Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
+                {0, 0});
+        }();
+
+        auto o_res =
+            make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
+                                      kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
+
+        auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
+        auto w_scale      = GetWeightScale(
+            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
+
+        auto uk_0  = Policy::template GetUK_0<Problem>();
+        auto acc_0 = uk_0(a_res,
+                          a_coords,
+                          g_res,
+                          g_coords,
+                          smem,
+                          kargs.hidden_size,
+                          kargs.stride_token,
+                          BlockShape::Block_Kr0 * BlockShape::Block_W0);
+
+        sweep_tile(acc_0,
+                   [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
+
+        auto y_pre = cast_tile<YDataType>(acc_0);
+        store_tile(bridge_sst_win, y_pre);
+
+        auto uk_1 = Policy::template GetUK_1<Problem>();
+        uk_1(d_res,
+             d_coords,
+             o_res,
+             o_coords,
+             smem,
+             kargs.hidden_size,
+             w_scale,
+             BlockShape::Block_Kr0 * BlockShape::Block_W0,
+             kargs.stride_token);
+    }
+};
+
+} // namespace ck_tile
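Note: to make GetRowCoords_A concrete, a worked example with the bf16 configuration selected above (BlockSize = 256 and Block_M0 = 32, Block_K0 = 128 from `S<32, 512, 128, 128>`; kAlignmentA = 8 is my assumption for 8 bf16 elements per load):

// KLans   = Block_K0 / kAlignmentA = 128 / 8  = 16 lanes cover one K-row
// MLans   = BlockSize / KLans      = 256 / 16 = 16 rows covered per issue
// MRepeat = Block_M0 / MLans       =  32 / 16 = 2 row coords per thread
// => thread t holds sorted-token rows { base + t/16, base + t/16 + 16 },
//    which GetRowID_A then gathers through sorted_token_ids_ptr.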
@@ -18,6 +18,7 @@ enum class WGAttrCtlEnum
     Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
     Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
     Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
+    Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
     // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
 };
@@ -38,6 +39,28 @@ enum class WGAttrCtlEnum
                      :);                                        \
     }

+#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)                       \
+    if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv)               \
+    {                                                           \
+        DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v")              \
+    }                                                           \
+    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa)          \
+    {                                                           \
+        DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v")              \
+    }                                                           \
+    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav)          \
+    {                                                           \
+        DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v")              \
+    }                                                           \
+    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva)          \
+    {                                                           \
+        DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v")              \
+    }                                                           \
+    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv)          \
+    {                                                           \
+        DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")              \
+    }
+
 // FP16
 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
 struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
@@ -72,22 +95,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
                     const BVecType& b_vec,
                     bool_constant<post_nop_> = {}) const
     {
-        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "a", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "a", "v")
-        }
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
         else
         {
 #if defined(__gfx9__)
@@ -147,22 +155,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
                     const BVecType& b_vec,
                     bool_constant<post_nop_> = {}) const
     {
-        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "a", "v")
-        }
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
         else
         {
 #if defined(__gfx9__)
@@ -223,22 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
                     const BVecType& b_vec,
                     bool_constant<post_nop_> = {}) const
     {
-        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "a", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "a", "v")
-        }
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
         else
         {
 #if defined(__gfx90a__) || defined(__gfx94__)
@@ -324,23 +302,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
                     const BVecType& b_vec,
                     bool_constant<post_nop_> = {}) const
     {
-        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "a", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
-        {
-            DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "a", "v")
-        }
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
         else
         {
 #if defined(__gfx90a__) || defined(__gfx94__)
             c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
@@ -623,22 +585,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
                     const BVecType& b_vec,
                     bool_constant<post_nop_> = {}) const
     {
-        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
-        {
-            DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
-        {
-            DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "a", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
-        {
-            DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "v", "v")
-        }
-        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
-        {
-            DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "a", "v")
-        }
+        DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
         else
         {
 #if defined(__gfx94__)
...
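Note: for orientation, my reading of what the `Raw_avv` case of `DISPATCH_MFMA_CTRL_` pins down (the body of `DISPATCH_MFMA_` is not shown in this diff, so the expansion below is an assumption inferred from the constraint strings): the accumulator lives in AGPRs (`"+a"`) while both A and B stay in VGPRs (`"v"`), which frees VGPRs for the new uk pipelines.

// Assumed shape of the expansion for Raw_avv (illustrative only):
asm volatile("v_mfma_f32_16x16x16bf16_1k %0, %1, %2, %0"
             : "+a"(c_vec)              // c: accumulate in AGPRs
             : "v"(a_vec), "v"(b_vec)); // a, b: VGPR operands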