Commit aef2b33c authored by coderfeli's avatar coderfeli
Browse files

build ok

parent 075a4a43
......@@ -11,9 +11,9 @@ template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// template float fused_moegemm_<
// fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
// >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
......
......@@ -10,9 +10,9 @@
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// template float fused_moegemm_<
// fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
// >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
......
......@@ -304,7 +304,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<GDataType> g_perm_host = gate_only? shuffle_moe_weight(g_host, prec_w, 1) : shuffle_moe_weight_gateup(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
......
......@@ -6,6 +6,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -57,7 +57,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
// TODO: note Nr/Kr/W need consider SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Block_Kr = Block_K / Warp_K; // 16
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
......@@ -89,6 +89,32 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return c_block_tensor;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockDistGUMerge()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N / 2, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTileGUMerge()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDistGUMerge();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 128;
static constexpr index_t Block_K = 256;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32;
static constexpr index_t BlockSize = 256;
// static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 8
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
}
};
} // namespace ck_tile
......@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp"
namespace ck_tile {
......@@ -14,7 +14,7 @@ namespace ck_tile {
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
......@@ -118,7 +118,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#include "uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
// [s_loop_cnt]"+s"(loop_cnt),
......@@ -181,10 +181,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
// [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
// [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
// [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
// [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
......@@ -262,7 +262,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
......@@ -288,7 +288,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(BCoords::size() == 4); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
......@@ -365,7 +365,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc"
#include "uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(n),
......
......@@ -33,7 +33,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Block_Kr = Block_K / Warp_K; // 16
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
......
......@@ -54,22 +54,6 @@
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
// " ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
// " ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
// " ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
// " ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
// " ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
// " ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
// " ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
// " ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
// " ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
// " ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
// " ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
// " ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
// " ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
// " ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
// " ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
// " ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
......@@ -86,39 +70,6 @@
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
// " ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
// " ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
// " ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
// " ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
// " ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
// " ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
// " ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
// " ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
// " ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
// " ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
// " ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
// " ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
// " ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
// " ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
// " ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
// " ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
// " ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
// " ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
// " ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
// " ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
// " ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
// " ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
// " ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
// " ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
// " ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
// " ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
// " ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
// " ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
// " ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
// " ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
// " ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
// " ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
......@@ -136,22 +87,6 @@
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
// " buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
// " buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
// " buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
// " buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
// " buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
// " buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
// " buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
// " buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
// " buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
// " buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
// " buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
// " buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
// " buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
// " buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
// " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
// " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_mov_b32 v64, 0 \n"
......@@ -281,97 +216,13 @@
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_ " %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(0) \n"
// _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], v[64:67] \n"
// " buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], v[64:67] \n"
// " buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], v[64:67] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], v[68:71] \n"
// " buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], v[68:71] \n"
// " buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], v[68:71] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_ " %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
// _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], v[72:75] \n"
// " buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], v[72:75] \n"
// " buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], v[72:75] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], v[76:79] \n"
// " buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], v[76:79] \n"
// " buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_ " %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(0) \n"
// _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], v[64:67] \n"
// " buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], v[64:67] \n"
// " buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], v[64:67] \n" _UK_MFMA_
// " [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], v[64:67] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], v[68:71] \n"
// " buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], v[68:71] \n"
// " buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], v[68:71] \n" _UK_MFMA_
// " [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], v[68:71] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_ " %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
// _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], v[72:75] \n"
// " buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], v[72:75] \n"
// " buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], v[72:75] \n" _UK_MFMA_
// " [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], v[72:75] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], v[76:79] \n"
// " buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], v[76:79] \n"
// " buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], v[76:79] \n" _UK_MFMA_
// " [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_ " %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_add_u32 s60, 0x00000100, s80 \n"
......@@ -555,132 +406,16 @@
" %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(0) \n"
// _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[192:193], v[160:161], v[80:83] "
// "\n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[194:195], v[162:163], v[80:83] \n"
// " buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[196:197], v[164:165], v[80:83] "
// "\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[198:199], v[166:167], "
// "v[80:83] \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[200:201], v[168:169], v[80:83] "
// "\n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[202:203], v[170:171], v[80:83] \n"
// " buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 "
// "\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[204:205], v[172:173], "
// "v[80:83] \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[206:207], v[174:175], v[80:83] "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[192:193], v[224:225], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[194:195], v[226:227], v[84:87] \n"
// " buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[196:197], v[228:229], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[198:199], v[230:231], v[84:87] "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[200:201], v[232:233], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[202:203], v[234:235], v[84:87] \n"
// " buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[204:205], v[236:237], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[206:207], v[238:239], v[84:87] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
// _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[208:209], v[160:161], v[88:91] "
// "\n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[210:211], v[162:163], v[88:91] \n"
// " buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[212:213], v[164:165], v[88:91] "
// "\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[214:215], v[166:167], "
// "v[88:91] \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[216:217], v[168:169], v[88:91] "
// "\n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[218:219], v[170:171], v[88:91] \n"
// " buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 "
// "\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[220:221], v[172:173], "
// "v[88:91] \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[222:223], v[174:175], v[88:91] "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[208:209], v[224:225], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[210:211], v[226:227], v[92:95] \n"
// " buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[212:213], v[228:229], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[214:215], v[230:231], v[92:95] "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[216:217], v[232:233], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[218:219], v[234:235], v[92:95] \n"
// " buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[220:221], v[236:237], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[222:223], v[238:239], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(0) \n"
// _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[224:225], v[176:177], v[80:83] "
// "\n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[226:227], v[178:179], v[80:83] \n"
// " buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[228:229], v[180:181], v[80:83] "
// "\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[230:231], v[182:183], "
// "v[80:83] \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[232:233], v[184:185], v[80:83] "
// "\n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[234:235], v[186:187], v[80:83] \n"
// " buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen "
// "offset:1024 \n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[236:237], "
// "v[188:189], v[80:83] \n" _UK_MFMA_
// " [%[c16], %[c17], %[c18], %[c19]], acc[238:239], v[190:191], v[80:83] "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[224:225], v[240:241], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[226:227], v[242:243], v[84:87] \n"
// " buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen "
// "offset:2048 \n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[228:229], "
// "v[244:245], v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[230:231], v[246:247], v[84:87] "
// "\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[232:233], v[248:249], "
// "v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[234:235], v[250:251], v[84:87] \n"
// " buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen "
// "offset:3072 \n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[236:237], "
// "v[252:253], v[84:87] \n" _UK_MFMA_
// " [%[c20], %[c21], %[c22], %[c23]], acc[238:239], v[254:255], v[84:87] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
// _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[240:241], v[176:177], v[88:91] "
// "\n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[242:243], v[178:179], v[88:91] \n"
// " buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen "
// "\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[244:245], v[180:181], "
// "v[88:91] \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[246:247], v[182:183], v[88:91] "
// "\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[248:249], v[184:185], "
// "v[88:91] \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[250:251], v[186:187], v[88:91] \n"
// " buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen "
// "offset:1024 \n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[252:253], "
// "v[188:189], v[88:91] \n" _UK_MFMA_
// " [%[c24], %[c25], %[c26], %[c27]], acc[254:255], v[190:191], v[88:91] "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[240:241], v[240:241], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[242:243], v[242:243], v[92:95] \n"
// " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen "
// "offset:2048 \n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[244:245], "
// "v[244:245], v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[246:247], v[246:247], v[92:95] "
// "\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[248:249], v[248:249], "
// "v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[250:251], v[250:251], v[92:95] \n"
// " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen "
// "offset:3072 \n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[252:253], "
// "v[252:253], v[92:95] \n" _UK_MFMA_
// " [%[c28], %[c29], %[c30], %[c31]], acc[254:255], v[254:255], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
......
......@@ -807,16 +807,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 && !Problem::Traits::IsGateOnly)
{
return Flatmm_32x256x128_1x4x1_16x16x32_BF16{};
return Flatmm_32x512x128_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 && !Problem::Traits::IsGateOnly)
{
return Flatmm_32x256x128_1x4x1_16x16x32_FP16{};
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
}
......
......@@ -199,8 +199,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
if (row_ids_a[0] >= kargs.num_tokens)
return;
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
......@@ -266,8 +265,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto d_coords = [&]() {
constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = BlockShape::Block_Kr1 / Kr1_; //4
constexpr index_t Kr1_ = 4;
constexpr index_t Kr0_ = BlockShape::Block_Kr1 / Kr1_; //4
constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8;
......@@ -300,7 +299,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto bridge_sst_win = [&]() {
constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDistGUMerge();
// constexpr auto dist_ = IsGateOnly ? Policy::template GetUK_0<Problem>().MakeCBlockDist()
// : Policy::template GetUK_0<Problem>().MakeCBlockDistGUMerge();
return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem), desc_),
desc_.get_lengths(),
......@@ -315,11 +316,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
if (row_ids_a[0] >= kargs.num_tokens)
return;
// if (row_ids_a[0] >= kargs.num_tokens)
// return;
auto uk_0_g = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0_g(a_res,
auto acc_0_full = uk_0_g(a_res,
a_coords,
g_res,
g_coords,
......@@ -328,7 +329,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
auto acc_0 = Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
if (!IsGateOnly) {
sweep_tile(acc_0, [&](auto idx0) {
acc_0(idx0) = acc_0_full(idx0);
});
}
// fast GeLu
if constexpr(std::is_same_v<typename Problem::GateActivation,
ck_tile::element_wise::FastGeluAsm>)
......@@ -350,37 +357,18 @@ struct FusedMoeGemmPipeline_FlatmmUk
[&](auto idx0) { typename Problem::GateActivation{}(acc_0(idx0), acc_0(idx0)); },
sequence<1, 1>{});
}
if (!IsGateOnly) {
for(auto i = 0; i < BlockShape::Repeat_N0; i++)
{
acc_0.get_thread_buffer()[4 * i + 0] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 0];
acc_0.get_thread_buffer()[4 * i + 1] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 1];
acc_0.get_thread_buffer()[4 * i + 2] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 2];
acc_0.get_thread_buffer()[4 * i + 3] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 3];
}
}
auto y_pre = acc_0;
block_sync_lds();
// up
// if(!IsGateOnly)
// {
// // up ptr. add hafl expoert_stride_0 as offset.
// auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
// auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
// auto u_coords =
// generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
// number<decltype(u_win)::NumAccess_NonLinear>{});
// // reuse UK0
// auto uk_0_u = Policy::template GetUK_0<Problem>();
// auto acc_0_u = uk_0_u(a_res,
// a_coords,
// u_res,
// u_coords,
// smem,
// kargs.hidden_size,
// BlockShape::Block_K0, // tile offset for B matrix each unroll
// BlockShape::Block_Kr0 *
// BlockShape::Block_W0); // tile offset for B matrix each unroll
// // elementwise mul gate*up.
// sweep_tile(
// y_pre,
// [&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
// sequence<1, 1>{});
// block_sync_lds();
// }
store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre));
block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment