"tests/pipelines/pixart_sigma/test_pixart.py" did not exist on "78be400761007b346f6d600c14343752d6c5ef2e"
Commit 06914eed authored by carlushuang's avatar carlushuang
Browse files

block-asm

parent b0dd570a
......@@ -12,6 +12,7 @@ set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: -Wno-undefined-func-template is set so the sources compile without explicitly declaring the function template specializations
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
......
......@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -33,11 +33,12 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::Gelu, // TODO: hardcoded
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_shape,
f_traits>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
......
......@@ -8,10 +8,7 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......
......@@ -62,6 +62,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......
......@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
};
template <bf16_rounding_mode rounding =
......@@ -180,6 +181,33 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32);
}
// float -> bf16 conversion, "round to nearest away" (rta) flavour.
// Returns the raw 16-bit bf16 pattern.
//
// Host overload.
// TODO: do we need this on host?
// It has no inline-asm path and simply forwards to float_to_bf16_rtn_raw().
// NOTE(review): judging by its name that helper rounds to nearest (even?), so host and
// device results may differ on exact ties -- confirm this is acceptable.
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }

// Device overload, implemented with three VALU instructions:
//   1. v_cmp_u_f32   : "unordered" self-compare, records which lanes hold a NaN
//   2. v_add3_u32    : x + 0x7fff + 1, i.e. adds 0x8000 (half of a bf16 ulp) to the raw bits
//   3. v_cndmask_b32 : replaces the result with 0x7fff0000 on the NaN lanes
// After that the bf16 value lives in the UPPER 16 bits of the 32-bit word.
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
    // NOTE: the asm writes through fp32 and the result is read back through int32
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;

    // 64-bit lane mask produced by v_cmp, held in a scalar register pair
    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;

    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
                 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
                 : [s_cnan] "=s"(check_nan), [v_x] "+v"(u.fp32)
                 : [v_blo] "v"(ROUND_BIAS_FOR_BF16), [v_bhi] "v"(FP32_NAN));

    // The rounded bf16 (and the 0x7fff NaN pattern) sit in bits [31:16]; the low half only
    // holds discarded mantissa bits, so it must be shifted out before narrowing.
    return uint16_t(u.int32 >> 16);
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
......@@ -213,6 +241,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}
......
......@@ -624,6 +624,40 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
// Derive the values needed to drive m0 for a store into the given LDS tile window.
// Returns the tuple [m0_init_value, size_per_issue]:
//   m0_init_value  -> directly use this to set the m0 value
//   size_per_issue -> directly use this to increment m0 after every issue
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_smem_info(LdsTileWindow_&& lds_tile) const
{
    using WindowT = remove_cvref_t<LdsTileWindow_>;
    using ElemT   = typename WindowT::DataType;

    // the window is indexed as issues * warps * lanes
    static_assert(WindowT::get_num_of_dimension() == 3); // TODO: hard coded

    // descriptor offset of a coordinate, scaled by the element size
    const auto scaled_offset = [&](auto coord) {
        return lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                   coord) *
               sizeof(ElemT);
    };

    // offset of the window origin (issue 0, warp 0, lane 0)
    const index_t base_os = scaled_offset(make_tuple(number<0>{}, number<0>{}, number<0>{}));

    // distance between two consecutive warps, relative to the origin
    const index_t wave_stride =
        scaled_offset(make_tuple(number<0>{}, number<1>{}, number<0>{})) - base_os;

    // distance between two consecutive issues, relative to the origin
    const index_t issue_stride =
        scaled_offset(make_tuple(number<1>{}, number<0>{}, number<0>{})) - base_os;

    // every warp starts from its own slot
    const index_t m0_start = base_os + wave_stride * get_warp_id();

    return make_tuple(m0_start, issue_stride);
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access = -1,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace ck_tile {
#if 1
// Takes an LDS store tile window and extracts the information needed to set the m0
// value on the gfx9 series.
// Returns the tuple [m0_init_value, size_per_issue].
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
    using WindowT = remove_cvref_t<LdsTileWindow_>;
    using ElemT   = typename WindowT::DataType;

    // the window is indexed as issues * warps * lanes
    static_assert(WindowT::get_num_of_dimension() == 3); // TODO: hard coded

    // descriptor offset of a coordinate, scaled by the element size
    const auto scaled_offset = [&](auto coord) {
        return lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                   coord) *
               sizeof(ElemT);
    };

    // offset of the window origin (issue 0, warp 0, lane 0)
    const index_t base_os = scaled_offset(make_tuple(number<0>{}, number<0>{}, number<0>{}));

    // distance between two consecutive warps, relative to the origin
    const index_t wave_stride =
        scaled_offset(make_tuple(number<0>{}, number<1>{}, number<0>{})) - base_os;

    // distance between two consecutive issues, relative to the origin
    const index_t issue_stride =
        scaled_offset(make_tuple(number<1>{}, number<0>{}, number<0>{})) - base_os;

    // every warp starts from its own slot
    const index_t m0_start = base_os + wave_stride * get_warp_id();

    return make_tuple(m0_start, issue_stride);
}
#else
// Macro flavour of get_async_store_smem_info(); only compiled when the "#if 1" above is
// flipped to 0. It expands to an immediately-invoked generic lambda that receives the LDS
// tile window by value and yields make_tuple(m0_init_value_, size_per_issue_), computed
// from the descriptor offsets of coordinates (0,0,0), (0,1,0) and (1,0,0), each scaled by
// sizeof(LdsDataType), plus the calling warp's id.
// The lambda captures by reference, so the macro must be expanded inside a function body.
#define GET_ASYNC_STORE_SMEM_INFO(lds_tile__) \
[&](auto lds_tile_) { \
using LdsTileWindow = remove_cvref_t<decltype(lds_tile_)>; \
using LdsDataType = typename LdsTileWindow::DataType; \
\
/* issues * warps * lanes */ \
static_assert(LdsTileWindow::get_num_of_dimension() == 3); \
\
const index_t size_per_buf_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<0>{}, number<0>{}, number<0>{})) * \
sizeof(LdsDataType); \
\
const index_t size_per_wave_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<0>{}, number<1>{}, number<0>{})) * \
sizeof(LdsDataType) - \
size_per_buf_; \
\
const index_t size_per_issue_ = \
lds_tile_.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( \
make_tuple(number<1>{}, number<0>{}, number<0>{})) * \
sizeof(LdsDataType) - \
size_per_buf_; \
\
const index_t m0_init_value_ = size_per_buf_ + size_per_wave_ * get_warp_id(); \
\
return make_tuple(m0_init_value_, size_per_issue_); \
}(lds_tile__)
#endif
} // namespace ck_tile
......@@ -572,6 +572,49 @@ struct FastGelu
}
};
// Fast GELU approximation:
//   y = x / (1 + exp(u)),  u = -2 * x * (0.035677 * x * x + 0.797885)
// Only the <float, float> specialization is provided. The host path evaluates the formula
// in plain C++; the device path is a hand-written instruction sequence.
struct FastGeluAsm
{
    template <typename Y, typename X>
    CK_TILE_HOST void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;

    // Host reference implementation.
    template <>
    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = exp(u);
        y               = x / (1.f + emu);
    }

    // device code, use lower precision "v_exp_f32" (exp2) and "v_rcp_f32"
    template <>
    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        // The constants must be real float VALUES: writing `const float c = 0xbd92220c;`
        // converts the integer (~3.1e9) to float instead of reinterpreting the bit pattern.
        const float c1     = -2.f * 0.035677f; // bit pattern ~0xbd92220c
        const float c2     = -2.f * 0.797885f;
        const float log2e_ = 1.44269504f; // log2(e), bit pattern ~0x3fb8aa3b

        // Scratch register. It is write-only from the compiler's point of view, and it is
        // written by the first instruction while x and c2 are still needed afterwards,
        // hence the early-clobber output constraint ("=&v") rather than "+v".
        float tmp;
        asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x]        ; x*x\n"
                     "v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
                     "v_mul_f32 %[v_tmp], %[v_tmp], %[v_x]      ; x*(c1*x*x+c2)\n"
                     "v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e]  ; log2e*x*(c1*x*x+c2)\n"
                     "v_exp_f32 %[v_tmp], %[v_tmp]              ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
                     "v_add_f32 %[v_tmp], %[v_tmp], 1.0         ; emu+1.0f\n"
                     "v_rcp_f32 %[v_tmp], %[v_tmp]              ; 1/(emu+1.0f)\n"
                     "v_mul_f32 %[v_y], %[v_tmp], %[v_x]        ; x * 1/(emu+1f)\n"
                     : [v_y] "=v"(y), [v_tmp] "=&v"(tmp)
                     : [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
                     :);
    }
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.hpp"
#include "ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 128;
static constexpr index_t Block_K = 512;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 16;
static constexpr index_t BlockSize = 256;
static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K * KPack; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / (Warp_K * KPack); // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
CK_TILE_LDS_ADDR void* smem,
// OWindow& o_window_,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t stride_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t stride_o)
{
// auto cached_coords_b = b_window_.cached_coords_;
// auto res_b =
// b_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto
// cached_coords_o = o_window_.cached_coords_; auto res_o =
// o_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t stride_b_bytes = stride_b * sizeof(BDataType);
const index_t stride_o_bytes = stride_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
int sfl_sst = (threadIdx.x % 16 * 4 + 4) * (threadIdx.x / 16);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
_Pragma("clang diagnostic push");
_Pragma("clang diagnostic ignored \"-Winline-asm\"");
asm volatile(
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt vmcnt(24) \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[16:17], v[128:129], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[18:19], v[130:131], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[20:21], v[132:133], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[22:23], v[134:135], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[24:25], v[136:137], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[26:27], v[138:139], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[28:29], v[140:141], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[30:31], v[142:143], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[16:17], v[192:193], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[18:19], v[194:195], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[20:21], v[196:197], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[22:23], v[198:199], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[24:25], v[200:201], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[26:27], v[202:203], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[28:29], v[204:205], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[30:31], v[206:207], [%[c12],%[c13],%[c14],%[c15]] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[32:33], v[144:145], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[34:35], v[146:147], [%[c0],%[c1],%[c2],%[c3]] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[36:37], v[148:149], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[38:39], v[150:151], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[40:41], v[152:153], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[42:43], v[154:155], [%[c0],%[c1],%[c2],%[c3]] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[44:45], v[156:157], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0],%[c1],%[c2],%[c3]], acc[46:47], v[158:159], [%[c0],%[c1],%[c2],%[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[32:33], v[208:209], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[34:35], v[210:211], [%[c4],%[c5],%[c6],%[c7]] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[36:37], v[212:213], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[38:39], v[214:215], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[40:41], v[216:217], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[42:43], v[218:219], [%[c4],%[c5],%[c6],%[c7]] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[44:45], v[220:221], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4],%[c5],%[c6],%[c7]], acc[46:47], v[222:223], [%[c4],%[c5],%[c6],%[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[48:49], v[144:145], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[50:51], v[146:147], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[52:53], v[148:149], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[54:55], v[150:151], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[56:57], v[152:153], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[58:59], v[154:155], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[60:61], v[156:157], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[62:63], v[158:159], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[48:49], v[208:209], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[50:51], v[210:211], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[52:53], v[212:213], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[54:55], v[214:215], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[56:57], v[216:217], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[58:59], v[218:219], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[60:61], v[220:221], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[62:63], v[222:223], [%[c12],%[c13],%[c14],%[c15]] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[80:81], v[160:161], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[82:83], v[162:163], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[84:85], v[164:165], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[86:87], v[166:167], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[88:89], v[168:169], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[90:91], v[170:171], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[92:93], v[172:173], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[94:95], v[174:175], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[80:81], v[224:225], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[82:83], v[226:227], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[84:85], v[228:229], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[86:87], v[230:231], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[88:89], v[232:233], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[90:91], v[234:235], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[92:93], v[236:237], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[94:95], v[238:239], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[112:113], v[176:177], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[114:115], v[178:179], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[116:117], v[180:181], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[118:119], v[182:183], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[120:121], v[184:185], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[122:123], v[186:187], [%[c8],%[c9],%[c10],%[c11]] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[124:125], v[188:189], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c8],%[c9],%[c10],%[c11]], acc[126:127], v[190:191], [%[c8],%[c9],%[c10],%[c11]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[112:113], v[240:241], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[114:115], v[242:243], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[116:117], v[244:245], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[118:119], v[246:247], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[120:121], v[248:249], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[122:123], v[250:251], [%[c12],%[c13],%[c14],%[c15]] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[124:125], v[252:253], [%[c12],%[c13],%[c14],%[c15]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c12],%[c13],%[c14],%[c15]], acc[126:127], v[254:255], [%[c12],%[c13],%[c14],%[c15]] \n"
// " s_add_u32 s60, 0x00000100, s80 \n"
// " s_cmp_lt_u32 s60, s81 \n"
// " s_cselect_b32 s56, s56, 0 \n"
// " s_add_u32 s12, s56, s12 \n"
// " s_addc_u32 s13, 0, s13 \n"
// Scale the 16 fp32 accumulators c0..c15 in place.
// c0-c3 and c8-c11 are multiplied by scale_0; c4-c7 and c12-c15 by scale_1.
// NOTE(review): what scale_0 / scale_1 represent is defined by the operand
// bindings outside this chunk - confirm there.
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n"
// Convert the 16 values c0..c15 pairwise and pack them into c0..c7.
// For each input c[i]:
//   v_cmp_u_f32   : s[32:33] = per-lane mask where c[i] is unordered with itself (NaN)
//   v_add3_u32    : v50 = c[i] + v_nan_lo + 1 (integer add on the raw bits)
//   v_cndmask_b32 : take v_nan_hi in NaN lanes, otherwise v50 (into v54 / v55)
// v_perm_b32 then merges bytes of v55 and v54 into one dword using the selector in s52.
// NOTE(review): presumably this is the fp32 -> bf16 rounding with two bf16 values packed
// per dword (the results feed global_atomic_pk_add_bf16 below); the contents of
// v_nan_lo, v_nan_hi and s52 are set outside this chunk - confirm.
// Ordering matters: output c[k] is written only after inputs c[2k] and c[2k+1] have
// been read, so the packing can be done in place.
// Scratch registers v50, v54, v55 and s[32:33] are hard-coded here.
// NOTE(review): they must not overlap compiler-allocated operands - verify the clobber list.
" v_cmp_u_f32 s[32:33], %[c0], %[c0] \n"
" v_add3_u32 v50, %[c0], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c1], %[c1] \n"
" v_add3_u32 v50, %[c1], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c0], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c2], %[c2] \n"
" v_add3_u32 v50, %[c2], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c3], %[c3] \n"
" v_add3_u32 v50, %[c3], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c1], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c4], %[c4] \n"
" v_add3_u32 v50, %[c4], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c5], %[c5] \n"
" v_add3_u32 v50, %[c5], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c2], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c6], %[c6] \n"
" v_add3_u32 v50, %[c6], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c7], %[c7] \n"
" v_add3_u32 v50, %[c7], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c3], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c8], %[c8] \n"
" v_add3_u32 v50, %[c8], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c9], %[c9] \n"
" v_add3_u32 v50, %[c9], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c4], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c10], %[c10] \n"
" v_add3_u32 v50, %[c10], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c11], %[c11] \n"
" v_add3_u32 v50, %[c11], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c5], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c12], %[c12] \n"
" v_add3_u32 v50, %[c12], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c13], %[c13] \n"
" v_add3_u32 v50, %[c13], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c6], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c14], %[c14] \n"
" v_add3_u32 v50, %[c14], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c15], %[c15] \n"
" v_add3_u32 v50, %[c15], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c7], v55, v54, s52 \n"
" ;------------------------------ \n"
// Exchange the packed results c0..c7 between lanes through LDS.
// Store: four 64-bit writes of the pairs (c0,c1), (c2,c3), (c4,c5), (c6,c7) at
// address v_sfl_sst with offsets 0, 4352, 2176 and 6528 (each plus shfl_base).
// Note the store offsets are not in ascending order.
// NOTE(review): shfl_base appears inside the offset expression, so it is presumably
// bound as a compile-time immediate - confirm in the operand list.
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] \n"
// Wait for this wave's LDS writes to complete, then barrier before reading back.
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
// Load: eight 32-bit reads from address v_sfl_sld, four at a 32-byte step starting at
// offset 0 and four at a 32-byte step starting at offset 4352 (each plus shfl_base).
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
// All eight reads must have landed before c0..c7 are used as atomic data below.
" s_waitcnt lgkmcnt(0) \n"
// Accumulate the shuffled results into the output with packed-bf16 atomic adds:
// c[k] is added at base s[8:9] plus the per-lane offset v_os_o[k], for k = 0..7.
// The commented-out s_mov_b64 lines would have loaded exec from s[16:17]..s[30:31]
// before each atomic and from s[36:37] after it; they are disabled, so the atomics
// run with whatever exec mask is already in effect.
// NOTE(review): confirm that no per-store lane masking is needed here.
//" s_mov_b64 exec, s[16:17] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], %[c0], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[18:19] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], %[c1], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[20:21] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], %[c2], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[22:23] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], %[c3], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[24:25] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], %[c4], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[26:27] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], %[c5], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[28:29] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], %[c6], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[30:31] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], %[c7], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
// Loop bookkeeping after the c0..c15 results have been issued.
// Decrement the loop counter and leave through L_end when it is no longer > 0
// (s_cbranch_scc0 is taken when the preceding compare is false).
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
// Advance the 64-bit base in s[12:13] (first two dwords of the s[12:15] descriptor used
// by the buffer loads) by s_stride_b, but only while more than one iteration remains;
// otherwise the increment selected into s86 is 0 and the base stays put.
// s86 is a hard-coded scratch SGPR.
// NOTE(review): verify s86 is covered by the asm clobber list.
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
// Advance the 64-bit base in s[8:9] (used by the atomic adds) by s_stride_o,
// unconditionally.
" s_add_u32 s8, %[s_stride_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
// Disabled alternative loop control based on s80/s81, kept for reference.
//" s_addk_i32 s80, 0x0080 \n"
//" s_cmp_lt_i32 s80, s81 \n"
//" s_cbranch_scc0 label_0E98 \n"
// MFMA phase accumulating into c16..c31, organised as four 4-register accumulator
// groups: c16-c19, c20-c23, c24-c27 and c28-c31.
// - The first MFMA into each group passes literal 0 as the addend, so the group starts
//   from zero; every later MFMA accumulates onto the group itself.
// - Operands come from acc[128:255] and v[128:255], two registers per MFMA.
// - One buffer_load_dwordx4 is interleaved roughly every two to four MFMAs; together they
//   refill acc[0:127] through descriptor s[12:15] with per-lane offsets v_os_b0..v_os_b7
//   and immediate offsets 0/1024/2048/3072.
// NOTE(review): presumably acc[0:127] is consumed by the next phase while this one
// computes (ping-pong prefetch) - confirm against the code outside this chunk.
// Instruction order here is scheduling-sensitive; do not reorder without re-measuring.
// Wait until at most 32 vector-memory operations remain outstanding, then barrier.
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]] \n"
// Same wait threshold again before the MFMAs start reading acc[160:191].
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]] \n"
// Last wait in this phase, before the MFMAs start reading acc[192:255].
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]] \n"
" v_mfma_f32_16x16x16_bf16 [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]] \n"
// " s_add_u32 s60, 0x00000100, s80 \n"
// " s_cmp_lt_u32 s60, s81 \n"
// " s_cselect_b32 s56, s56, 0 \n"
// " s_add_u32 s12, s56, s12 \n"
// " s_addc_u32 s13, 0, s13 \n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
" v_cmp_u_f32 s[32:33], %[c16], %[c16] \n"
" v_add3_u32 v50, %[c16], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c17], %[c17] \n"
" v_add3_u32 v50, %[c17], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c16], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c18], %[c18] \n"
" v_add3_u32 v50, %[c18], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c19], %[c19] \n"
" v_add3_u32 v50, %[c19], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c17], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c20], %[c20] \n"
" v_add3_u32 v50, %[c20], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c21], %[c21] \n"
" v_add3_u32 v50, %[c21], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c18], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c22], %[c22] \n"
" v_add3_u32 v50, %[c22], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c23], %[c23] \n"
" v_add3_u32 v50, %[c23], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c19], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c24], %[c24] \n"
" v_add3_u32 v50, %[c24], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c25], %[c25] \n"
" v_add3_u32 v50, %[c25], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c20], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c26], %[c26] \n"
" v_add3_u32 v50, %[c26], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c27], %[c27] \n"
" v_add3_u32 v50, %[c27], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c21], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c28], %[c28] \n"
" v_add3_u32 v50, %[c28], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c29], %[c29] \n"
" v_add3_u32 v50, %[c29], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c22], v55, v54, s52 \n"
" ;------------------------------ \n"
" v_cmp_u_f32 s[32:33], %[c30], %[c30] \n"
" v_add3_u32 v50, %[c30], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[32:33] \n"
" v_cmp_u_f32 s[32:33], %[c31], %[c31] \n"
" v_add3_u32 v50, %[c31], %[v_nan_lo], 1 \n"
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[32:33] \n"
" v_perm_b32 %[c23], v55, v54, s52 \n"
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
//" s_mov_b64 exec, s[16:17] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], %[c16], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[18:19] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], %[c17], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[20:21] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], %[c18], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[22:23] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], %[c19], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[24:25] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], %[c20], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[26:27] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], %[c21], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[28:29] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], %[c22], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[30:31] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], %[c23], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_stride_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_stride_o]"s"(stride_o_bytes),
[s_stride_b]"s"(stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
// "s32", "s33",
"v50", "v54", "v55",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
_Pragma("clang diagnostic pop");
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// A async load to LDS, B direct to AGPR
// B matrix preshuffled in br*kr*w
// require 4 wave, occupancy=1c
// agpr usage:256
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
{
// Compile-time shape of this micro-kernel. All sizes are in elements.
// Block-level tile sizes along M / N / K.
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 512;
static constexpr index_t Block_K = 128;
// Number of warps along M / N / K; NumWarps equals their product (1 * 4 * 1).
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t NumWarps = 4;
// Warp-level tile sizes along M / N / K.
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
static constexpr index_t BlockSize = 256;
static constexpr index_t SubKPacks = 2; // this is used to guarantee every thread can do dwordx4
// TODO: note Nr/Kr/W need to consider SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 8 per wave (32 / WarpPerBlock_N)
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
// Repeat_*: block tile size divided by (warp tile size * warps along that dimension).
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 128 / (32 * 1) = 4
// Create the distributed tensor that holds this block's C (accumulator) tile.
//
// A block-level ("outer") encoding built from Repeat_M/WarpPerBlock_M and
// Repeat_N/WarpPerBlock_N is embedded with the warp-level C encoding of the
// 16x16x32 transposed-C warp gemm; the resulting tile distribution is used
// to instantiate a float tensor, which is returned by value.
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
    using WarpGemm = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
    using AccType  = float;

    using OuterEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<2, 1>, // !! note: this ordering is <2, 1> here, which is different
        sequence<0, 0>>;

    constexpr auto outer_enc = OuterEncoding{};
    constexpr auto block_enc = detail::make_embed_tile_distribution_encoding(
        outer_enc, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto block_dstr = make_static_tile_distribution(block_enc);

    auto acc_tile = make_static_distributed_tensor<AccType>(block_dstr);
    return acc_tile;
}
// LDS descriptor for the "A async->LDS" store side.
//
// The A tile is brought into LDS by async copies of k_vec elements (one dword)
// per lane; there is no explicit LDS store instruction. Each wave-wide issue
// occupies lanes_per_wave * k_vec elements followed by a pad of k_pack elements.
//
// Returns a 3-d descriptor indexed as (issue, warp, lane-within-issue).
// If more waves are needed along K than the block has (waves_along_k > NumWarps)
// nothing is returned yet -- that case is still a TODO.
// NOTE(review): MakeLdsLoadDesc_A guards the same TODO with `>=` instead of `>`;
// confirm which comparison is intended.
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
    constexpr index_t lanes_per_wave = ck_tile::get_warp_size();
    constexpr index_t k_pack         = 8;      // LDS k-pack
    constexpr index_t k_vec          = 2;      // async copy 1 dword
    constexpr index_t wave_pad       = k_pack; // pad between warps

    // elements spanned in LDS by one wave-wide issue, pad included
    constexpr index_t wave_span = lanes_per_wave * k_vec + wave_pad;

    static_assert(Block_K % k_vec == 0);
    constexpr index_t lanes_along_k = Block_K / k_vec; // how many threads load along K

    if constexpr(lanes_along_k >= lanes_per_wave)
    {
        // one row of K needs one or more whole waves
        static_assert(lanes_along_k % lanes_per_wave == 0);
        constexpr index_t waves_along_k = lanes_along_k / lanes_per_wave;

        if constexpr(!(waves_along_k > NumWarps))
        {
            constexpr index_t waves_along_m = NumWarps / waves_along_k;
            constexpr index_t num_issues    = Block_M / waves_along_m;

            constexpr auto raw_desc = make_naive_tensor_descriptor(
                make_tuple(number<num_issues>{},     // m0
                           number<waves_along_m>{},  // m1
                           number<waves_along_k>{},  // k0
                           number<lanes_per_wave>{}, // k1
                           number<k_vec>{}),         // k2
                make_tuple(number<NumWarps * wave_span>{},      // m0
                           number<waves_along_k * wave_span>{}, // m1
                           number<wave_span>{},                 // k0
                           number<k_vec>{},                     // k1
                           number<1>{}),                        // k2
                number<k_vec>{}, // lds store vector (actually no explicit store)
                number<1>{});

            // (m0) | (m1, k0) | (k1, k2)  ->  issues | warps | lanes
            constexpr auto issues_warps_lanes = transform_tensor_descriptor(
                raw_desc,
                make_tuple(
                    make_pass_through_transform(number<num_issues>{}),
                    make_merge_transform(
                        make_tuple(number<waves_along_m>{}, number<waves_along_k>{})),
                    make_merge_transform(
                        make_tuple(number<lanes_per_wave>{}, number<k_vec>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
            return issues_warps_lanes;
        }
        // TODO: otherwise multiple issues along K are needed to load all data
    }
    else
    {
        // lanes within a wave load different M but the same K
        static_assert(lanes_per_wave % lanes_along_k == 0);
        constexpr index_t lane_groups = lanes_per_wave / lanes_along_k; // along m
        constexpr index_t num_issues  = Block_M / (lane_groups * NumWarps);

        constexpr auto raw_desc = make_naive_tensor_descriptor(
            make_tuple(number<num_issues>{},    // m0
                       number<lane_groups>{},   // m1
                       number<NumWarps>{},      // m2
                       number<lanes_along_k>{}, // k0
                       number<k_vec>{}),        // k1
            make_tuple(number<NumWarps * wave_span>{}, // m0
                       number<Block_K>{},              // m1
                       number<wave_span>{},            // m2
                       number<k_vec>{},                // k0
                       number<1>{}),                   // k1
            number<k_vec>{}, // lds store vector (actually no explicit store)
            number<1>{});

        // (m0) | (m2) | (m1, k0, k1)  ->  issues | warps | lanes
        constexpr auto issues_warps_lanes = transform_tensor_descriptor(
            raw_desc,
            make_tuple(make_pass_through_transform(number<num_issues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(number<lane_groups>{},
                                                       number<lanes_along_k>{},
                                                       number<k_vec>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
        return issues_warps_lanes;
    }
}
// template <typename Problem>
// LDS descriptor for the "A async->LDS" load side.
//
// This descriptor only describes the layout of A inside LDS; the ds_read
// instructions of the real gemm pipeline may not follow this pattern (they may
// follow the tile distribution instead). It is almost the same as the store-side
// descriptor, except that:
//   1) the guaranteed last-dimension vector length of the naive descriptor is
//      k_pack instead of the async-copy vector, and
//   2) the returned descriptor is a merged 2-d (M x K) layout.
//
// Only the case 1 <= waves_along_k < NumWarps is implemented; the other
// branches are still TODO and return nothing.
// NOTE(review): MakeLdsStoreDesc_A guards its TODO branch with `>` while this
// function uses `>=`; confirm which comparison is intended.
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
    constexpr index_t lanes_per_wave = ck_tile::get_warp_size();
    constexpr index_t k_pack         = 8;      // LDS k-pack
    constexpr index_t k_vec          = 2;      // async copy 1 dword
    constexpr index_t row_pad        = k_pack; // pad between warps

    constexpr index_t m_lanes    = 16;
    constexpr index_t k_lanes    = 4;
    constexpr index_t k_per_lane = 4;
    constexpr index_t k_iters    = 2;
    static_assert(k_pack == (k_per_lane * k_iters));

    static_assert(Block_K % k_vec == 0);
    constexpr index_t lanes_along_k = Block_K / k_vec; // how many threads load along K

    if constexpr(lanes_along_k >= lanes_per_wave)
    {
        // one row of K needs one or more whole waves
        static_assert(lanes_along_k % lanes_per_wave == 0);
        constexpr index_t waves_along_k = lanes_along_k / lanes_per_wave;

        if constexpr(waves_along_k < NumWarps)
        {
            // TODO: every wave loads the same data!
            static_assert(Block_K % (k_lanes * k_pack) == 0);
            constexpr index_t k_issues   = Block_K / (k_lanes * k_pack); // 4
            constexpr index_t m_issues   = Block_M / (m_lanes);          // 2
            constexpr index_t row_stride = Block_K + row_pad;            // one padded row of M

            constexpr auto raw_desc = make_naive_tensor_descriptor(
                make_tuple(number<m_issues>{}, // m0
                           number<m_lanes>{},  // m1
                           number<k_issues>{}, // k0
                           number<k_lanes>{},  // k1
                           number<k_pack>{}),  // k2
                make_tuple(number<m_lanes * row_stride>{}, // m0
                           number<row_stride>{},           // m1
                           number<k_lanes * k_pack>{},     // k0
                           number<k_pack>{},               // k1
                           number<1>{}),                   // k2
                number<k_pack>{}, // lds load vector
                number<1>{});

            // (m0, m1) | (k0, k1, k2)  ->  M | K
            return transform_tensor_descriptor(
                raw_desc,
                make_tuple(
                    make_merge_transform(make_tuple(number<m_issues>{}, number<m_lanes>{})),
                    make_merge_transform(make_tuple(
                        number<k_issues>{}, number<k_lanes>{}, number<k_pack>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }
        // TODO: otherwise multiple issues along K are needed to load all data
    }
    // lanes_along_k < lanes_per_wave: not implemented, nothing is returned
}
static constexpr auto GetGemm_AWarpEnc()
{
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
using enc_ = tile_distribution_encoding<
sequence<>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
return enc_{};
}
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
#if 0
template <typename AWindow, typename BWindow, typename SmemWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const BWindow& b_window_,
SmemWindow& smem_window_,
index_t k,
index_t stride_a,
index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
#else
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t stride_a,
index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
#endif
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR bf16_t*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR bf16_t*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t stride_a_bytes = stride_a * sizeof(bf16_t);
const index_t stride_b_bytes = stride_b * sizeof(bf16_t);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(bf16_t)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
// clang-format off
_Pragma("clang diagnostic push");
_Pragma("clang diagnostic ignored \"-Winline-asm\"");
asm volatile(
"s_mov_b32 s16, %[s_res_a0] \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_stride_a], 0 \n"
"s_add_u32 s16, s86, s16 \n"
"s_addc_u32 s17, 0, s17 \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_stride_a], 0 \n"
"s_add_u32 s16, s86, s16 \n"
"s_addc_u32 s17, 0, s17 \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_stride_b], 0 \n"
"s_add_u32 s20, s86, s20 \n"
"s_addc_u32 s21, 0, s21 \n"
"s_waitcnt vmcnt(40)\n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_stride_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_stride_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
""
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(bf16_t))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(bf16_t))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(bf16_t))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(bf16_t))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(bf16_t))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(bf16_t))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(bf16_t))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(bf16_t))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(bf16_t))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(bf16_t))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(bf16_t))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(bf16_t))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(bf16_t))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(bf16_t))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(bf16_t))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(bf16_t))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(bf16_t))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_stride_a]"s"(stride_a_bytes),
[s_stride_b]"s"(stride_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
_Pragma("clang diagnostic pop");
// clang-format on
(void)smem_buf_size;
(void)sld_os;
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
} // namespace ck_tile
......@@ -8,6 +8,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
......@@ -133,7 +133,8 @@ struct FusedMoeGemmKernel
// Type aliases and compile-time switches forwarded from the pipeline's Problem description.
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;

// Trait bundle of the pipeline problem; the flags below that read Traits:: are taken from it.
using Traits = typename Pipeline::Problem::Traits;

// NOTE(review): hard-coded to true; presumably selects the micro-kernel ("uk") pipeline
// path rather than the generic one -- confirm against the uses in operator().
static constexpr bool UseUK = true;

// Gate-only switch; operator() uses it to choose a factor of 1 (gate only) or 2 when
// computing expert_stride_0.
static constexpr bool IsGateOnly = Traits::IsGateOnly;

// NOTE(review): presumably enables the smooth-quant scaling inputs -- confirm against Traits.
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
......@@ -211,157 +212,179 @@ struct FusedMoeGemmKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size);
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
}
else
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 =
kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id =
__builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id =
a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size,
kargs.stride_token);
}
}
};
......
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
......@@ -318,6 +319,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
// number<S_::WarpPerBlock_N0>{}.rrr();
// number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0,
......@@ -556,7 +559,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr index_t KPad = 0; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......@@ -573,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = KVector; // pad between warps
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......@@ -589,7 +592,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
using S_ = typename Problem::BlockShape;
// A is vgpr, B is agpr. But since we transposed, so also need swap this
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
......@@ -716,7 +719,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
......@@ -812,5 +815,31 @@ struct FusedMoeGemmPipelineFlatmmPolicy
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor;
}
// Selects the micro-kernel object used for the first gemm of the fused MoE pipeline.
// Only one configuration is handled: bf16 A/G data, a 32x512x128 block tile (M0/N0/K0) and a
// 16x16x32 warp tile (Warp_M0/N0/K0). For any other configuration no return statement is
// instantiated, so the deduced return type is void.
// NOTE(review): the returned type name ends in 16x16x16 while Warp_K0 is matched against 32 —
// confirm the naming is intentional.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
{
    using Shape_ = typename Problem::BlockShape;

    constexpr bool is_bf16_input_ =
        std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
        std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t>;
    constexpr bool block_tile_matches_ =
        Shape_::Block_M0 == 32 && Shape_::Block_N0 == 512 && Shape_::Block_K0 == 128;
    constexpr bool warp_tile_matches_ =
        Shape_::Warp_M0 == 16 && Shape_::Warp_N0 == 16 && Shape_::Warp_K0 == 32;

    if constexpr(is_bf16_input_ && block_tile_matches_ && warp_tile_matches_)
    {
        return FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16{};
    }
}
// Selects the micro-kernel object used for the second gemm of the fused MoE pipeline.
// Only one configuration is handled: bf16 Y/D data, a 32x128x512 block tile (M1/N1/K1) and a
// 16x16x32 warp tile. For any other configuration no return statement is instantiated, so the
// deduced return type is void.
// NOTE(review): the warp-tile check reads Warp_M0/N0/K0 (the first-gemm names) although the
// block-tile check reads the *_1 members — confirm this is intended.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
    using Shape_ = typename Problem::BlockShape;

    constexpr bool is_bf16_input_ =
        std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
        std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t>;
    constexpr bool block_tile_matches_ =
        Shape_::Block_M1 == 32 && Shape_::Block_N1 == 128 && Shape_::Block_K1 == 512;
    constexpr bool warp_tile_matches_ =
        Shape_::Warp_M0 == 16 && Shape_::Warp_N0 == 16 && Shape_::Warp_K0 == 32;

    if constexpr(is_bf16_input_ && block_tile_matches_ && warp_tile_matches_)
    {
        return FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16{};
    }
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deals with a gemm (actually 2 gemms) where one operand is very small (token) and
the other is very big (weight). The pipeline is designed so that all waves are laid out along
the gemm-N dim (only 1 wave along gemm-M).
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-M
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
// Blocks per CU: taken from Problem::kBlockPerCu unless that is the -1 sentinel, in which case
// the fallback value 2 is used.
static constexpr index_t kBlockPerCu = []() -> index_t {
    constexpr index_t fallback_blocks_ = 2; // minimize occupancy
    if constexpr(Problem::kBlockPerCu == -1)
        return fallback_blocks_;
    else
        return Problem::kBlockPerCu;
}();
static constexpr const char* name = "fused_moe_flatmm_uk";
// TODO: there are multiple buffers
// Returns the value of Policy::GetSmemSize_A for this Problem; the computation itself lives in
// the policy, this is only a forwarding wrapper.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy::template GetSmemSize_A<Problem>();
}
// Returns the value of Policy::GetSmemSize for this Problem; the computation itself lives in
// the policy, this is only a forwarding wrapper.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
// Returns calculate_index() of the A global tile distribution supplied by the policy.
CK_TILE_HOST_DEVICE static auto GetACoord()
{
    constexpr auto dist_a_ = Policy::template MakeGlobalTileDistribution_A<Problem>();
    return dist_a_.calculate_index();
}
// this is the thread-offset along row/col
// Returns calculate_index() of the O global tile distribution supplied by the policy.
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
    constexpr auto dist_o_ = Policy::template MakeOGlobalTileDistribution<Problem>();
    return dist_o_.calculate_index();
}
// Number of row coordinates each thread holds for the A tile:
//   Block_M0 / (BlockSize / (Block_K0 / kAlignmentA)).
// This is the same count GetRowCoords_A uses for the size of its result.
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
{
    constexpr index_t lanes_k_  = BlockShape::Block_K0 / kAlignmentA;
    constexpr index_t lanes_m_  = BlockShape::BlockSize / lanes_k_;
    constexpr index_t repeat_m_ = BlockShape::Block_M0 / lanes_m_;
    return repeat_m_;
}
// TODO: properly support scatter/gather
// Builds this thread's row coordinates for the A tile.
// The first coordinate is threadIdx.x / lanes_k_ + base_offset; each following entry is
// lanes_m_ rows further, for repeat_m_ entries in total.
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
    constexpr index_t lanes_k_  = BlockShape::Block_K0 / kAlignmentA;
    constexpr index_t lanes_m_  = BlockShape::BlockSize / lanes_k_;
    constexpr index_t repeat_m_ = BlockShape::Block_M0 / lanes_m_;

    auto first_row = threadIdx.x / lanes_k_ + base_offset;

    array<index_t, repeat_m_> row_coords;
    static_for<0, repeat_m_, 1>{}([&](auto i_rep) {
        row_coords.at(i_rep) = first_row + i_rep * lanes_m_;
    });
    return row_coords;
}
// Looks up sorted_token_ids_ptr at every entry of coords and returns the loaded values as an
// array with the same number of entries as coords.
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords,
                               const IndexDataType* sorted_token_ids_ptr)
{
    constexpr index_t num_rows_ = coords.size();

    array<index_t, num_rows_> token_rows;
    static_for<0, num_rows_, 1>{}([&](auto i_row) {
        const auto sorted_pos    = coords[i_row];
        token_rows.at(i_row) = sorted_token_ids_ptr[sorted_pos];
    });
    return token_rows;
}
// TODO: properly support scatter/gather
// Builds this thread's row coordinates for the O side using a fixed lane count of 16 along M.
// The first coordinate is threadIdx.x % 16 + base_offset; each following entry is 16 rows
// further, for Block_M0 / 16 entries in total.
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
{
    constexpr index_t lanes_m_  = 16; // TODO: use 16x16
    constexpr index_t repeat_m_ = BlockShape::Block_M0 / lanes_m_;

    auto first_row = threadIdx.x % lanes_m_ + base_offset;

    array<index_t, repeat_m_> row_coords;
    static_for<0, repeat_m_, 1>{}([&](auto i_rep) {
        row_coords.at(i_rep) = first_row + i_rep * lanes_m_;
    });
    return row_coords;
}
// Loads the top-k weight for every entry of coords from sorted_weight_ptr.
//   coords            - array-like of row positions; its size() fixes the result size.
//   sorted_weight_ptr - weight table indexed by those row positions.
// Returns an array of TopkWeightDataType with one weight per entry of coords.
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                   const TopkWeightDataType* sorted_weight_ptr)
{
    constexpr index_t n_size = coords.size();

    // Keep the weights in their own data type. Storing them in an index_t array would convert
    // every loaded weight to an integer before it is used as a scale.
    array<TopkWeightDataType, n_size> w;
    static_for<0, n_size, 1>{}([&](auto i) {
        w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
    });
    return w;
}
// Offset-free overload: builds this thread's row coordinates from the N1/M1 block sizes.
// The first coordinate is threadIdx.x / lanes_n_; each following entry is lanes_m_ rows
// further, for repeat_m_ entries in total.
// NOTE(review): the lane count along N is derived from kAlignmentA although this is the O
// side — confirm kAlignmentO was not intended.
CK_TILE_DEVICE auto GetRowCoords_O()
{
    constexpr index_t lanes_n_  = BlockShape::Block_N1 / kAlignmentA;
    constexpr index_t lanes_m_  = BlockShape::BlockSize / lanes_n_;
    constexpr index_t repeat_m_ = BlockShape::Block_M1 / lanes_m_;

    auto first_row = threadIdx.x / lanes_n_;

    array<index_t, repeat_m_> row_coords;
    static_for<0, repeat_m_, 1>{}([&](auto i_rep) {
        row_coords.at(i_rep) = first_row + i_rep * lanes_m_;
    });
    return row_coords;
}
/*
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
*/
// Device entry point of the micro-kernel ("uk") pipeline for one tile.
// Steps visible in this body:
//   1. derive weight tile counts, the expert id and the per-expert strides from kargs;
//   2. gather the A row ids through sorted_token_ids_ptr and multiply them by stride_token;
//   3. build linear tile windows for G and D and take their cached buffer resources;
//   4. run the first micro-kernel (Policy::GetUK_0), apply Problem::GateActivation to every
//      accumulator element, cast to YDataType and store it through the LDS bridge window;
//   5. run the second micro-kernel (Policy::GetUK_1) with the D/O resources and w_scale.
// Parameters:
//   kargs                - kernel arguments (see the commented-out FusedMoeGemmKargs layout).
//   smem                 - LDS pointer; used for the bridge store and passed to both
//                          micro-kernels.
//   sorted_tile_id       - tile index; indexes sorted_expert_ids_ptr and is scaled by Block_M0.
//   intermediate_tile_id - tile index; scaled by Block_Nr0.
// The result of the second micro-kernel call is not returned from this function.
template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs,
CK_TILE_LDS_ADDR void* smem,
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
// tile counts used as the lengths/strides of the G and D tensor views below
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
// expert handled by this tile, read through readfirstlane
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// first-gemm weight holds one (gate only) or two intermediate_size slices per expert
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
// A side: per-thread sorted row positions -> token row ids -> row ids scaled by stride_token
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID_A(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto a_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
number<row_ids_a.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
// G side: view over this expert's weight starting at this tile's interm_idx_nr position
const auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
// number<BlockShape::Block_Nr0>{}.fff();
// number<kAlignmentG>{}.zzz();
const auto g_window_ =
make_tile_window_linear(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
// number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
// only the window's cached buffer resource and cached coordinate offsets are used from here on
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
// D side: view over this expert's second-gemm weight
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<kAlignmentD>{},
number<1>{});
const auto d_window_ =
make_tile_window_linear(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
return d_window_;
}();
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
// the window-derived D offsets are disabled; the #else branch computes them by hand
#if 0
auto d_coords = generate_tuple([&](auto i) {
return d_win.cached_coords_[i].get_offset(); },
number<decltype(d_win)::NumAccess_NonLinear>{});
#else
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto d_coords = [&]() {
// hard-coded decomposition matching the layout sketched in the comment above
constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 4;
constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8;
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
constexpr index_t num_offsets_ = Nr_ * Kr0_;
// per-thread base: (threadIdx.x % 64) selects the Kv_-wide slot, (threadIdx.x / 64) the
// Kr0_ * Kr1_ * W_ sized chunk
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_;
return generate_tuple(
[&](auto i) {
// i enumerates (i_kr0_, i_nr_) with i_nr_ varying fastest
constexpr auto i_nr_ = number<i % Nr_>{};
constexpr auto i_kr0_ = number<i / Nr_>{};
return i_nr_ * kargs.intermediate_size * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
base_os_;
},
number<num_offsets_>{});
}();
#endif
// NOTE(review): the output offsets reuse row_ids_a (built from GetRowCoords_A), while w_scale
// below is built from GetRowCoords_O — confirm the two row mappings are meant to differ.
auto o_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
number<a_coords.size()>{});
// LDS window through which the first gemm's activated result is stored
auto bridge_sst_win = [&]() {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0});
}();
// NOTE(review): o_ptr is cast to pointer-to-const here although it is the output pointer —
// confirm make_wave_buffer_resource expects this.
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
// first micro-kernel; its result is the accumulator tile acc_0
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
kargs.stride_token,
BlockShape::Block_Kr0 * BlockShape::Block_W0);
// apply the gate activation element-wise, in place on acc_0
sweep_tile(acc_0,
[&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
// second micro-kernel; it receives the same smem pointer the bridge store wrote to
// NOTE(review): the Block_Kr0 * Block_W0 argument is the same expression passed to uk_0 —
// confirm the *_1 constants were not intended for the second gemm.
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
d_coords,
o_res,
o_coords,
smem,
kargs.hidden_size,
w_scale,
BlockShape::Block_Kr0 * BlockShape::Block_W0,
kargs.stride_token);
}
};
} // namespace ck_tile
......@@ -18,6 +18,7 @@ enum class WGAttrCtlEnum
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
......@@ -38,6 +39,28 @@ enum class WGAttrCtlEnum
:); \
}
// Expands to an `if constexpr` / `else if constexpr` chain over the Raw_* values of
// WGAttrCtlEnum. Each branch forwards the instruction name `mfma_` to DISPATCH_MFMA_ together
// with the register-constraint strings for that mode; going by the enum comments, the first
// three strings are the c, a and b operand classes ("v" = vgpr, "a" = agpr).
// NOTE(review): the fourth string follows the c operand class in every branch — presumably the
// accumulator input constraint; confirm against DISPATCH_MFMA_.
// The chain has no trailing `else`, so a call site must follow the macro with its own
// `else { ... }` fallback if the fallback is meant to be skipped for the Raw_* modes.
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
......@@ -72,22 +95,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
else
{
#if defined(__gfx9__)
......@@ -147,22 +155,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
else
{
#if defined(__gfx9__)
......@@ -223,22 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
......@@ -324,23 +302,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "a", "v")
}
else
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
......@@ -623,22 +585,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "a", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "a", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
DISPATCH_MFMA_("v_mfma_i32_32x32x16_i8", "+v", "v", "a", "v")
}
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
else
{
#if defined(__gfx94__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment