Commit b4887801 authored by carlushuang's avatar carlushuang
Browse files

tmp

parent a5670e67
......@@ -45,6 +45,34 @@ struct BlockFmhaPipelineQRAsyncEx
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
// Gemm tiling constants re-exported from BlockFmhaShape so the pipeline body can use
// them unqualified. For each gemm (suffix 0 / 1) and each of M, N, K, BlockFmhaShape
// provides:
//   Block_*      - block tile extent
//   BlockWarps_* - element of Gemm{0,1}BlockWarps
//   Warps_*      - element of Gemm{0,1}WarpTile
//   Repeat_*     - Block_* / (BlockWarps_* * Warps_*)

// gemm-0
static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0;
static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0;
static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0;
static constexpr index_t BlockWarps_M0 = BlockFmhaShape::BlockWarps_M0;
static constexpr index_t BlockWarps_N0 = BlockFmhaShape::BlockWarps_N0;
static constexpr index_t BlockWarps_K0 = BlockFmhaShape::BlockWarps_K0;
static constexpr index_t Warps_M0 = BlockFmhaShape::Warps_M0;
static constexpr index_t Warps_N0 = BlockFmhaShape::Warps_N0;
static constexpr index_t Warps_K0 = BlockFmhaShape::Warps_K0;
static constexpr index_t Repeat_M0 = BlockFmhaShape::Repeat_M0;
static constexpr index_t Repeat_N0 = BlockFmhaShape::Repeat_N0;
static constexpr index_t Repeat_K0 = BlockFmhaShape::Repeat_K0;
// gemm-1
static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1;
static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1;
static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1;
static constexpr index_t BlockWarps_M1 = BlockFmhaShape::BlockWarps_M1;
static constexpr index_t BlockWarps_N1 = BlockFmhaShape::BlockWarps_N1;
static constexpr index_t BlockWarps_K1 = BlockFmhaShape::BlockWarps_K1;
static constexpr index_t Warps_M1 = BlockFmhaShape::Warps_M1;
static constexpr index_t Warps_N1 = BlockFmhaShape::Warps_N1;
static constexpr index_t Warps_K1 = BlockFmhaShape::Warps_K1;
static constexpr index_t Repeat_M1 = BlockFmhaShape::Repeat_M1;
static constexpr index_t Repeat_N1 = BlockFmhaShape::Repeat_N1;
static constexpr index_t Repeat_K1 = BlockFmhaShape::Repeat_K1;
// Number of per-stage copies kept of the accumulator / M / L tiles (used as the tuple
// length and static_for bound further down in this pipeline).
static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
......@@ -176,11 +204,10 @@ struct BlockFmhaPipelineQRAsyncEx
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
auto k_lds_store = [&](){
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
......@@ -189,28 +216,64 @@ struct BlockFmhaPipelineQRAsyncEx
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
}();
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>());
auto k_lds_load = make_tile_window(
k_lds_Load_view, Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
auto k_lds_load = [&](){
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr,
Policy::template MakeSmemLoadDesc_K<Problem>()),
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
}();
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr), Policy::template MakeSmemLoadDesc_V<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
auto v_lds_store = [&](){
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchV>{});
}();
auto v_lds_load = [&](){
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr,
Policy::template MakeSmemLoadDesc_V<Problem>()),
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
}();
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// Block GEMM
constexpr auto gemm_0 = Policy::template GetBlockGemm_0<Problem>();
constexpr auto gemm_1 = Policy::template GetBlockGemm_1<Problem>();
constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();
auto gemm_0 = [&](){
constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
// n*k*m, more relaxed ds_read
static_for<0, total_repeats, 1>{}(
[&](auto i_r){
constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
}
);
};
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeGlobalDesc_Q<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
......@@ -221,12 +284,8 @@ struct BlockFmhaPipelineQRAsyncEx
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<2>{});
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
......@@ -234,14 +293,14 @@ struct BlockFmhaPipelineQRAsyncEx
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());
// init Oacc, M, L
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<2>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<2>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<2>{});
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
static_for<0, 2, 1>{}([&](auto i) {
static_for<0, UnrollStages, 1>{}([&](auto i) {
clear_tile(o_accs(i));
set_tile(ms(i), -numeric<SMPLComputeDataType>::infinity());
clear_tile(ls(i));
......
......@@ -102,28 +102,20 @@ struct BlockFmhaPipelineQRAsyncEx
}
// NOTE(review): this span is a flattened diff hunk. Two signatures (GetBlockGemm_0 and
// MakeBlockGemmAccTile_0) and two bodies appear back to back; presumably the
// GetBlockGemm_0 lines are the removed side and the MakeBlockGemmAccTile_0 lines the
// added side - confirm against the commit before editing.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm_0()
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
{
// --- GetBlockGemm_0 body: assembles a BlockGemmARegBSmemCRegV2 from a pipeline problem
// (Q/K/Sacc data types, kM0 x kN0 x kK0 tile) and a custom policy built on the warp gemm.
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = GetWarpGemm_0<Problem>();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
// --- MakeBlockGemmAccTile_0 body: takes the warp gemm's CWarpDstrEncoding, combines it
// with the gemm-0 block / block-warps / warp-tile extents via make_block_gemm_acc_enc,
// and returns a static distributed tensor of SaccDataType with that distribution.
// NOTE(review): GetWarpGemm_0 is invoked as GetWarpGemm_0<Problem>() a few lines above
// but without a template argument here - confirm this call resolves.
// NOTE(review): confirm BlockFmhaShape exposes Warp_M0 / Warp_N0 under exactly these names.
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
return t;
}
template <typename Problem>
......@@ -451,13 +443,8 @@ struct BlockFmhaPipelineQRAsyncEx
{
if constexpr(Problem::kHasDropout)
{
constexpr auto gemm_0 = QXPolicy::template GetBlockGemm_0<Problem>();
constexpr auto config =
decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kMPerStep = Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep = Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
}
......@@ -622,29 +609,20 @@ struct BlockFmhaPipelineQRAsyncEx
}
// NOTE(review): this span is a flattened diff hunk. Two signatures (GetBlockGemm_1 and
// MakeBlockGemmAccTile_1) and two bodies appear back to back; presumably the
// GetBlockGemm_1 lines are the removed side and the MakeBlockGemmAccTile_1 lines the
// added side - confirm against the commit before editing.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm_1()
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
{
// --- GetBlockGemm_1 body: assembles a BlockGemmARegBSmemCRegV2 from a pipeline problem
// (P/V/Oacc data types, kM0 x kN1 x kK1 tile) and a custom policy built on the warp gemm.
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
auto warp_gemm = GetWarpGemm_1<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
// --- MakeBlockGemmAccTile_1 body: takes the warp gemm's CWarpDstrEncoding, combines it
// with the gemm-1 block / block-warps / warp-tile extents via make_block_gemm_acc_enc,
// and returns a static distributed tensor of OaccDataType with that distribution.
// NOTE(review): GetWarpGemm_1 is invoked as GetWarpGemm_1<Problem>() a few lines above
// but without a template argument here - confirm this call resolves.
// NOTE(review): confirm BlockFmhaShape exposes Warp_M1 / Warp_N1 under exactly these names.
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
return t;
}
};
......
......@@ -41,6 +41,39 @@ struct TileFmhaShape
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
// gemm-0 shapes TODO: naming?
static constexpr index_t Block_M0 = kM0;
static constexpr index_t Block_N0 = kN0;
static constexpr index_t Block_K0 = kK0;
static constexpr index_t BlockWarps_M0 = Gemm0BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N0 = Gemm0BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K0 = Gemm0BlockWarps::at(number<2>{});
static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{});
static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0);
static_assert(Block_N0 % (BlockWarps_N0 * Warps_N0) == 0);
static_assert(Block_K0 % (BlockWarps_K0 * Warps_K0) == 0);
static constexpr index_t Repeat_M0 = Block_M0 / (BlockWarps_M0 * Warps_M0);
static constexpr index_t Repeat_N0 = Block_N0 / (BlockWarps_N0 * Warps_N0);
static constexpr index_t Repeat_K0 = Block_K0 / (BlockWarps_K0 * Warps_K0);
static constexpr index_t Block_M1 = kM0;
static constexpr index_t Block_N1 = kN1;
static constexpr index_t Block_K1 = kK1;
static constexpr index_t BlockWarps_M1 = Gemm1BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N1 = Gemm1BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K1 = Gemm1BlockWarps::at(number<2>{});
static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{});
static_assert(Block_M1 % (BlockWarps_M1 * Warps_M1) == 0);
static_assert(Block_N1 % (BlockWarps_N1 * Warps_N1) == 0);
static_assert(Block_K1 % (BlockWarps_K1 * Warps_K1) == 0);
static constexpr index_t Repeat_M1 = Block_M1 / (BlockWarps_M1 * Warps_M1);
static constexpr index_t Repeat_N1 = Block_N1 / (BlockWarps_N1 * Warps_N1);
static constexpr index_t Repeat_K1 = Block_K1 / (BlockWarps_K1 * Warps_K1);
};
template <typename BlockTile_, // sequence<...
......
......@@ -21,6 +21,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_utils.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template<typename AccWarpDescEnc,
typename BlockTile, // seq<M, N>
typename BlockWarps,
typename WarpTile>
CK_TILE_DEVICE_HOST constexpr auto make_block_gemm_acc_enc()
{
constexpr index_t Block_M = BlockTile::at(number<0>{});
constexpr index_t Block_N = BlockTile::at(number<1>{});
constexpr index_t BlockWarps_M = BlockWarps::at(number<0>{});
constexpr index_t BlockWarps_N = BlockWarps::at(number<1>{});
constexpr index_t Warp_M = WarpTile::at(number<0>{});
constexpr index_t Warp_N = WarpTile::at(number<1>{});
constexpr index_t Repeat_M = Block_M / (BlockWarps_M * Warp_M);
constexpr index_t Repeat_N = Block_N / (BlockWarps_N * Warp_N);
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, BlockWarps_M>, sequence<Repeat_N, BlockWarps_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, AccWarpDescEnc{});
return c_block_dstr_encode;
}
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment