"docs/vscode:/vscode.git/clone" did not exist on "231e830084d62c8c3cc9c744b4ff74ef2280bbdc"
Commit ef5e60f6 authored by illsilin's avatar illsilin
Browse files

Merge branch 'develop' into gfx950

parents 2cc0fa26 5e93fa9e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 128;
static constexpr index_t Block_K = 512;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32;
static constexpr index_t BlockSize = 256;
// static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
the files under this folder should not be included directly!
\ No newline at end of file
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
# define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt 0 \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], %[c15]]\n"
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n"
_UK_PK_CVT_("%[c0]", "%[c1]", "%[c0]")
_UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]")
_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]")
_UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]")
_UK_PK_CVT_("%[c8]", "%[c9]", "%[c4]")
_UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]")
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]")
_UK_PK_CVT_("%[c14]", "%[c15]", "%[c7]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]]\n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]")
_UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]")
_UK_PK_CVT_("%[c20]", "%[c21]", "%[c18]")
_UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]")
_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]")
_UK_PK_CVT_("%[c26]", "%[c27]", "%[c21]")
_UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]")
_UK_PK_CVT_("%[c30]", "%[c31]", "%[c23]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c16], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c17], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c18], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c19], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c20], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c21], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c22], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c23], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#endif
"s_mov_b32 s16, %[s_res_a0] \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
" s_nop 2 \n"
#undef _UK_MFMA_
......@@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -470,7 +470,8 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
template <bool Cond = kIsGroupMode>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
......@@ -484,9 +485,129 @@ struct FmhaBwdDQDKDVKernel
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
......@@ -513,13 +634,134 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -616,6 +858,208 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
......
......@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
......@@ -267,50 +267,50 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -399,8 +399,9 @@ struct FmhaFwdKernel
return kargs;
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -408,9 +409,99 @@ struct FmhaFwdKernel
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
......@@ -431,13 +522,104 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -522,15 +704,173 @@ struct FmhaFwdKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......
......@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
......
......@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......@@ -46,8 +47,7 @@ struct FmhaFwdSplitKVKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
"paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -172,13 +172,18 @@ struct FmhaFwdSplitKVKernel
float scale_p;
};
struct PageBlockTableKargs
struct CommonPageBlockTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
};
struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
{
bool is_gappy = false;
};
struct CacheBatchIdxKargs
{
const int32_t* cache_batch_idx;
......@@ -193,13 +198,15 @@ struct FmhaFwdSplitKVKernel
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
{
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
// single kcache page-block
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
// single vcache page-block
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
......@@ -212,14 +219,17 @@ struct FmhaFwdSplitKVKernel
AlibiKargs,
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
// for single kcache page-block
ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
// for single vcache page-block
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
......@@ -230,8 +240,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
......@@ -352,8 +364,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
......@@ -363,6 +377,10 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
bool is_gappy,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
......@@ -416,6 +434,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
......@@ -443,6 +462,13 @@ struct FmhaFwdSplitKVKernel
{
kargs.scale_p = scale_p;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
kargs.is_gappy = is_gappy;
}
return kargs;
}
......@@ -476,11 +502,13 @@ struct FmhaFwdSplitKVKernel
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_k = 0; // unused for paged-kvcache
long_index_t batch_offset_v = 0; // unused for paged-kvcache
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
index_t kv_l2p_offset =
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
if constexpr(kIsGroupMode)
{
......@@ -490,7 +518,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
......@@ -525,6 +552,15 @@ struct FmhaFwdSplitKVKernel
{
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
if constexpr(kIsPagedKV)
{
if(kargs.is_gappy)
{
// seqstart_k_ptr has different meaning in this case
kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
}
}
}
else
{
......@@ -570,9 +606,9 @@ struct FmhaFwdSplitKVKernel
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
......@@ -587,14 +623,14 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}
}();
......@@ -609,7 +645,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
......@@ -642,7 +678,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<kPadHeadDimV, false>{});
}
else
{
......@@ -656,7 +692,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<false, kPadSeqLenK>{});
}
};
const auto v_dram = [&]() {
......@@ -677,7 +713,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
......@@ -685,14 +721,15 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
kargs.batch_stride_k, // kcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
(kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
......@@ -707,7 +744,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
......@@ -715,14 +752,15 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
kargs.batch_stride_v, // vcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
(kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
......@@ -766,9 +804,8 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......@@ -870,6 +907,7 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
kv_l2p_offset,
smem_ptr);
}
else
......@@ -886,6 +924,7 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
kv_l2p_offset,
smem_ptr);
}
}();
......
......@@ -18,11 +18,11 @@ struct FmhaFwdSplitKVTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
......
......@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
......@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
......@@ -142,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
static_assert(
......@@ -210,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
if(original_num_total_loop <= 0)
const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
......@@ -238,33 +240,41 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
// make sure the first tile is completely located in page-block
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
......@@ -378,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
// position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
});
});
}
......@@ -396,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
});
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
k_origin.at(number<0>{}) - kv_l2p_offset,
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
......@@ -427,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
......@@ -658,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
......@@ -680,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask,
position_encoding,
scale_s,
kv_l2p_offset,
smem_ptr);
}
};
......
......@@ -331,7 +331,8 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
......@@ -355,6 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
......@@ -386,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
......@@ -514,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
......@@ -618,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
......@@ -665,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
......
......@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasUnevenSplits_,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
const void* sorted_weight_ptr; // [max_num_tokens_padded]
const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
using DDataType = typename Pipeline::Problem::DDataType;
using AccDataType = typename Pipeline::Problem::AccDataType;
using ODataType = typename Pipeline::Problem::ODataType;
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool UseUK = true;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = BlockShape;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, GDataType>) {
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
}
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
// clang-format on
}
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using Kargs = FusedMoeGemmKargs;
using Hargs = FusedMoeGemmHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
// TODO: hargs/kargs not guranteed to be the same
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
// if(threadIdx.x == 0)
// printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
// intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
// num_sorted_tiles? 1 : 0, intermediate_tile_id);
if(sorted_tile_id >= num_sorted_tiles)
return;
Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
}
else
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 =
kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id =
__builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id =
a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size,
kargs.stride_token);
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N1
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K1 | | |
N0 N0 x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K0 | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M0 | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
----------> | | |
contiguous | V V
| x-----x +----------+
+------------> M1 | Y | ---------> | Out(O) |
ACT x-----x +----------+
K1 = N0 dim
* Note: Act could be Gelu/Silu/...
* Note: some model does not have Up
*/
template <typename BlockTile_0_,
typename WarpPerBlock_0_,
typename WarpTile_0_,
typename BlockTile_1_,
typename WarpPerBlock_1_,
typename WarpTile_1_>
struct FusedMoeGemmShape
{
using BlockTile_0 = remove_cvref_t<BlockTile_0_>;
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
using WarpTile_0 = remove_cvref_t<WarpTile_0_>;
using BlockTile_1 = remove_cvref_t<BlockTile_1_>;
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
static constexpr index_t BlockSize = warpSize * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);
static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up
// pre-shuffle tile size compute (assume only for B matrix)
// we flatten the each wave tile to a 1d linear tensor(at model loading time)
// e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
// we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
// and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1;
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
static_assert(Block_W0 == Block_W1);
// static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
// FusedMoeGemmShape
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
static constexpr const char* name = "lin";
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t /*intermediate_size*/)
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
return ck_tile::make_tuple(i_m, i_n);
}
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
{
// TODO: this may need tuning
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
return dim3(ns, ms, 1);
}
};
} // namespace ck_tile
......@@ -12,20 +12,77 @@
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
struct MoeSortingHostArgs
{
const void* p_topk_ids;
const void* p_weights;
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
// we fused the setzero of output of fused-moe buffer
// set this pointer to nullptr will skip this operation
void* p_moe_buf;
index_t tokens;
index_t unit_size;
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
index_t moe_buf_bytes;
index_t moe_buf_bytes; // byte size of p_moe_buf
};
template <typename Problem_>
......@@ -183,8 +240,14 @@ struct MoeSortingKernel
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
#else
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
p_sorted_weights[rank_post_pad] = weights[i];
#endif
p_sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
}
......@@ -195,8 +258,13 @@ struct MoeSortingKernel
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
while(expert_offset < cumsum[tid + 1])
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
p_sorted_token_ids[expert_offset] = prefill_token;
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
}
}
......@@ -229,4 +297,7 @@ struct MoeSortingKernel
smem);
}
};
#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmEx
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "fused_moe_flatmm";
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
{
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
constexpr auto NEG1 = number<-1>{};
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto TRUE = bool_constant<true>{};
constexpr auto FALSE = bool_constant<false>{};
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
Policy::template GetSmemSize_A<Problem>());
auto g_view = g_window_.get_bottom_tensor_view();
auto u_view = [&]() {
if constexpr(IsGateOnly)
{
return g_view;
}
else
{
index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
const GDataType* g_ptr =
g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
u_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
const auto u_view_1_ =
pad_tensor_view(u_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
return u_view_1_;
}
}();
auto a_win = make_tile_window_linear(
a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_win =
make_tile_window_linear(g_window_,
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
auto d_win =
make_tile_window_linear(d_window_,
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
auto o_win = make_tile_window_linear(
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
using g_thread_type = decltype(load_tile(g_win));
using d_thread_type = decltype(load_tile(d_win));
using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
// issues_warps_lanes
auto a_sst_win0 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
auto a_sst_win1 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// m*k
auto a_sld_win0 = [&]() {
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
sequence<BlockShape::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
// m*k
auto a_sld_win1 = [&]() {
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
sequence<BlockShape::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
auto bridge_sst_win = [&]() {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0});
}();
auto bridge_sld_win = [&]() {
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsLoadDesc<Problem>()),
Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
}();
// also OK with C array, 2 register buffer
statically_indexed_array<g_thread_type, 2> gs;
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{};
// constexpr auto issues_d = number<d_win.get_num_of_access()>{};
// constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr auto issues_gemm0 =
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
warp_gemm_0.get_num_of_access()>{};
constexpr auto issues_gemm1 =
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
warp_gemm_1.get_num_of_access()>{};
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const index_t num_blocks_k0 =
(hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
const index_t num_blocks_n1 =
(hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
using a_thread_type = decltype(load_tile(a_sld_win0));
statically_indexed_array<a_thread_type, 2> as;
auto gld_a = [&]<typename PreNop = bool_constant<false>>(
auto& a_store_, auto i_access, PreNop = {})
{
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
};
auto move_a = [&]() {
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
};
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access);
};
auto gld_g = [&]<typename PreNop = bool_constant<false>>(
auto& g_, auto i_access, PreNop = {})
{
if constexpr(IsGateOnly)
{
// TODO: hack!
if constexpr(i_access.value == 0)
{
g_win.bottom_tensor_view_ = g_view;
}
else if constexpr(i_access.value == issues_g / 2)
{
g_win.bottom_tensor_view_ = u_view;
}
}
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
};
auto move_g = [&]() {
move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
};
statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
auto& d_, auto i_access, PreNop = {})
{
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
};
auto move_d = [&]() {
// d move along gemm-n
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
};
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {})
{
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
};
auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
auto acc_1s = generate_tuple(
[&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
// clang-format off
auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor w_b;
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor w_c;
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
w_c.get_thread_buffer());
};
// clang-format on
// clang-format off
auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor w_b;
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor w_c;
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
w_c.get_thread_buffer());
};
// clang-format on
_Pragma("clang diagnostic pop");
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
// mfma(that can reuse the B matrix) only affected by M repeat.
auto pipeline_gemm0 = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_sld_a_0 = MAKE_SC();
constexpr auto c_gld_a_0 = MAKE_SC();
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
constexpr auto c_sld_a_1 = MAKE_SC();
constexpr auto c_gld_a_1 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
};
auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1);
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
constexpr auto last_nop = [&]() {
if constexpr(i_issue == (total_loops - 1))
return TRUE;
else
return FALSE;
}();
gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
});
};
auto y = Policy::template MakeYBlockTile<Problem>();
auto pipeline_bridge = [&]() {
// cast to Y data
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
clear_tile(acc_1s(I0));
// wave_barrier();
load_tile(y, bridge_sld_win);
clear_tile(acc_1s(I1));
};
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto pipeline_gemm1 = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
constexpr auto c_gst_o_0 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
constexpr auto c_gst_o_1 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
move_d();
// move_o();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
}
});
move_d();
};
auto pipeline_gemm1_head = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_d();
};
auto pipeline_gemm1_tail = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gst_o_0 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
{
auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, NEG1);
}
};
// start of pipeline
// clang-format off
gld_a(a_sst_win0, NEG1, TRUE);
gld_g(gs[I0], NEG1, TRUE);
move_a();
move_g();
clear_tile(acc_0);
// preload for next round
gld_a(a_sst_win1, NEG1);
gld_g(gs[I1], NEG1);
// make sure a,g loaded
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while(i_0++ < iters_0)
{
pipeline_gemm0();
}
pipeline_gemm0_tail();
pipeline_bridge();
const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head();
while(i_1++ < iters_1)
{
pipeline_gemm1();
}
pipeline_gemm1_tail();
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
struct FusedMoeGemmPipelineFlatmmPolicy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO: always 1 dword
return 1;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
{
if constexpr(Problem::Traits::OAtomic == 1)
{
// pack fp16/bf16 atomic
static_assert(sizeof(typename Problem::ODataType) == 2);
return 2;
}
else if constexpr(Problem::Traits::OAtomic == 2)
{
// fp32 atomic
return 1;
}
else
{
return 16 / sizeof(typename Problem::ODataType);
}
}
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
// TODO: this is for 3d layout
return 16 / sizeof(remove_cvref_t<DataType_>);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
return GetSmemKPack<typename Problem::ADataType>();
}
// used for bridge LDS shuffle
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
{
// TODO: this should match mfma layout
return 16 / sizeof(typename Problem::YDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
return a_sld_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
static_assert(bridge_sld_desc.get_element_space_size() ==
bridge_sst_desc.get_element_space_size());
return bridge_sld_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t a_lds = GetSmemSize_A<Problem>();
constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
return max(a_lds, bridge_lds);
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// optimized version for async, not same as simple MXK dist(pay attention!!)
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() <= K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
// Note M_wave(num waves) is the fastest dim, different from sipmle 2d
// distribution
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
template <index_t WarpPerBlock_N_,
index_t WarpPerBlock_K_,
index_t Repeat_N_,
index_t Repeat_K_,
index_t WarpSize_,
index_t Alignment_>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
sequence<Repeat_K_, WarpPerBlock_K_>,
sequence<WarpSize_, Alignment_>>,
tuple<sequence<1, 2>, sequence<3>>,
tuple<sequence<1, 1>, sequence<0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
constexpr index_t Alignment_ = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<Block_M_,
Block_K_,
NumWarps_,
Alignment_>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
// number<S_::WarpPerBlock_N0>{}.rrr();
// number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0,
S_::Repeat_K0,
get_warp_size(),
GetAlignment_G<Problem>()>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
S_::WarpPerBlock_K1,
S_::Repeat_N1,
S_::Repeat_K1,
get_warp_size(),
GetAlignment_D<Problem>()>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK >= NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
{
constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0;
constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0;
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t KPack = kABKPerLane;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<Repeat_M>{}, // m
number<Repeat_N>{}, // n
number<WarpPerBlock_N>{}, // n
number<kABKLane>{}, // n
number<kAMLane>{}, // m
number<KPack>{}), // n
make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // n
number<kABKLane * kAMLane * KPack>{}, // n
number<kAMLane * KPack>{}, // n
number<KPack>{}, // m
number<1>{}), // n
number<KPack>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(make_tuple(number<Repeat_N>{},
number<WarpPerBlock_N>{},
number<kABKLane>{},
number<KPack>{}))),
make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
using S_ = typename Problem::BlockShape;
// A is vgpr, B is agpr. But since we transposed, so also need swap this
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
{
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
auto y_block_tensor =
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x512x128_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
}
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment