Unverified Commit e71aa1d6 authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

unify q persistent in register (#24)

* unify q persistent in register

* add refactor warp_gemm dispatcher
parent 02d69525
#include <cstring> #include <cstring>
#include <ostream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -15,6 +16,8 @@ ...@@ -15,6 +16,8 @@
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp"
#include "ck/tile_program/tile/tile_fmha_shape.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp"
...@@ -33,10 +36,14 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm ...@@ -33,10 +36,14 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck::half_t; using ODataType = ck::half_t;
// M0 N0 K0 N1 K1 // M0 N0 K0 N1 K1 K0L
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>; // using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>; // using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
using FmhaShape = ck::tile_program::TileFmhaShape<128, 128, 32, 128, 32>; using FmhaBlockTile = ck::Sequence<128, 128, 32, 128, 32, 128>;
using FmhaBlockWarps = ck::Sequence<4, 1, 1>;
using FmhaWarpTile = ck::Sequence<32, 32, 16>;
using FmhaShape = ck::tile_program::
TileFmhaShape<FmhaBlockTile, FmhaBlockWarps, FmhaWarpTile, FmhaBlockWarps, FmhaWarpTile>;
using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>; using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;
using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType, using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
...@@ -49,7 +56,8 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QD ...@@ -49,7 +56,8 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QD
ODataType, ODataType,
256, // BlockSize 256, // BlockSize
FmhaShape>; FmhaShape>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>; // using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;
using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>; using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
using FmhaKernel = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>; using FmhaKernel = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
...@@ -134,7 +142,7 @@ int main(int argc, char* argv[]) ...@@ -134,7 +142,7 @@ int main(int argc, char* argv[])
<< ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
<< ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
<< ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z << ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z
<< std::endl; << std::flush << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize; constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) // P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
#define C_LOG2E 1.44269504088896340736 // log2(e) #define C_LOG2E 1.44269504088896340736 // log2(e)
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel struct FmhaFwdKernel
...@@ -148,10 +148,16 @@ struct FmhaFwdKernel ...@@ -148,10 +148,16 @@ struct FmhaFwdKernel
Number<32>{}, Number<32>{},
Number<1>{}); Number<1>{});
auto q_dram_window = auto q_dram_window = make_tile_window(
make_tile_window(q_dram, q_dram,
make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{}), [&]() {
{i_m0, 0}); if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(Number<FmhaPipeline::kM0>{},
Number<FmhaPipeline::kK0BlockLength>{});
else
return make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0}); k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0});
......
...@@ -26,9 +26,11 @@ namespace block { ...@@ -26,9 +26,11 @@ namespace block {
// This will: // This will:
// 1. Load B from global memory into shared memory and then // 1. Load B from global memory into shared memory and then
// 2. Call BlockGemmARegSGmemCRegV1 // 2. Call BlockGemmARegSGmemCRegV1
template <typename Problem, typename Policy = BlockGemmARegBGmemCRegV1DefaultPolicy> template <typename Problem_, typename Policy_ = BlockGemmARegBGmemCRegV1DefaultPolicy>
struct BlockGemmARegBGmemCRegV1 struct BlockGemmARegBGmemCRegV1
{ {
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
...@@ -37,13 +39,9 @@ struct BlockGemmARegBGmemCRegV1 ...@@ -37,13 +39,9 @@ struct BlockGemmARegBGmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl = using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegV1<BlockGemmARegBSmemCRegProblem<ADataType, BlockGemmARegBSmemCRegProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BDataType, BlockGemmARegBSmemCRegV1DefaultPolicy>;
CDataType,
kBlockSize,
BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
__host__ __device__ static constexpr ck::index_t GetStaticLdsSize() __host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
{ {
......
...@@ -23,9 +23,11 @@ namespace block { ...@@ -23,9 +23,11 @@ namespace block {
// A is block distributed tensor // A is block distributed tensor
// B is block window on shared memory // B is block window on shared memory
// C is block distributed tensor // C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy> template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegV1 struct BlockGemmARegBSmemCRegV1
{ {
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Custom policy for BlockGemmARegBSmemCRegV1: instead of letting a default
// policy derive the warp-level gemm, the user supplies the exact warp-gemm
// instance (WarpGemm_) and the warp partition of the block tile
// (BlockWarps_ = Sequence<MWarps, NWarps, KWarps>).
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmARegBSmemCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    using BlockWarps = remove_cvref_t<BlockWarps_>;

    static constexpr index_t kMWarps = BlockWarps::At(Number<0>{});
    static constexpr index_t kNWarps = BlockWarps::At(Number<1>{});
    // NOTE(review): kKWarps is never consumed by GetWarpGemmMWarpNWarp() below;
    // K-direction warp splitting is presumably unsupported here — confirm
    // before passing KWarps != 1.
    static constexpr index_t kKWarps = BlockWarps::At(Number<2>{});

    using WarpGemm = remove_cvref_t<WarpGemm_>;

    // Returns the (warp-gemm instance, MWarps, NWarps) tuple that the
    // block-gemm implementations query from their policy.
    template <typename Problem>
    __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // (removed an unused `using namespace ck::tile_program::warp;` —
        // WarpGemm is already a resolved alias)
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
} // namespace block
} // namespace tile_program
} // namespace ck
...@@ -24,9 +24,11 @@ namespace block { ...@@ -24,9 +24,11 @@ namespace block {
// A is block window on shared memory // A is block window on shared memory
// B is block window on shared memory // B is block window on shared memory
// C is block distributed tensor // C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegV1DefaultPolicy> template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockGemmASmemBSmemCRegV1 struct BlockGemmASmemBSmemCRegV1
{ {
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Custom policy for BlockGemmASmemBSmemCRegV1 (the previous comment called it
// the "default" policy — it is the custom one): the warp-gemm is derived via
// WarpGemmMfmaDispatcher from the data types and the per-warp tile sizes
// (WarpTile_), with an explicit warp partition (BlockWarps_).
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpTile_,
          bool TranposeC_>
struct BlockGemmASmemBSmemCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    using BlockWarps = remove_cvref_t<BlockWarps_>;
    using WarpTile   = remove_cvref_t<WarpTile_>;

    static constexpr index_t BlockMWarps = BlockWarps::At(Number<0>{});
    static constexpr index_t BlockNWarps = BlockWarps::At(Number<1>{});
    // NOTE(review): BlockKWarps is never consumed by GetWarpGemmMWarpNWarp()
    // below — confirm before passing KWarps != 1.
    static constexpr index_t BlockKWarps = BlockWarps::At(Number<2>{});

    static constexpr index_t MPerWarp = WarpTile::At(Number<0>{});
    static constexpr index_t NPerWarp = WarpTile::At(Number<1>{});
    static constexpr index_t KPerWarp = WarpTile::At(Number<2>{});

    // (sic) "TranposeC" spelling preserved — it is part of the public name.
    static constexpr bool TranposeC = TranposeC_;

    using WarpGemm = ck::tile_program::warp::
        WarpGemmMfmaDispatcher<AType, BType, CType, MPerWarp, NPerWarp, KPerWarp, TranposeC>;

    // Returns the (warp-gemm instance, MWarps, NWarps) tuple that the
    // block-gemm implementations query from their policy.
    template <typename Problem>
    __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // (removed an unused `using namespace ck::tile_program::warp;` —
        // WarpGemm is already a resolved alias)
        return make_tuple(WarpGemm{}, BlockMWarps, BlockNWarps);
    }
};
} // namespace block
} // namespace tile_program
} // namespace ck
...@@ -35,7 +35,8 @@ struct BlockFmhaPipelineQKVS ...@@ -35,7 +35,8 @@ struct BlockFmhaPipelineQKVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr bool kQLoadOnce = false; // if q load whole block length (hdim) at once
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Fused-attention forward pipeline with Q persistent in registers ("QR"):
// Q is loaded from DRAM once per block and kept in registers for the whole
// main loop, while K ("KS") and V ("VS") tiles are streamed through LDS.
// NOTE: an earlier comment said "qkv all located in LDS"; in this pipeline
// only K and V go through LDS (see kQLoadOnce and GetSmemSizeQ()==0).
template <typename Problem, typename Policy = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct BlockFmhaPipelineQRKSVS
{
    using QDataType           = remove_cvref_t<typename Problem::QDataType>;
    using KDataType           = remove_cvref_t<typename Problem::KDataType>;
    using VDataType           = remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType        = remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using PDataType           = remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType        = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType           = remove_cvref_t<typename Problem::ODataType>;
    using BlockFmhaShape      = remove_cvref_t<typename Problem::BlockFmhaShape>;

    // Q is loaded for the whole block length (hdim, kK0BlockLength) at once;
    // the kernel checks this flag when sizing the Q DRAM window.
    static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM0 = BlockFmhaShape::kM0; // rows of Q/S/O handled per block
    static constexpr index_t kN0 = BlockFmhaShape::kN0; // rows of K per main-loop iteration
    static constexpr index_t kK0 = BlockFmhaShape::kK0; // gemm-0 K-depth per sub-step
    static constexpr index_t kN1 = BlockFmhaShape::kN1; // columns of V/O
    static constexpr index_t kK1 = BlockFmhaShape::kK1; // gemm-1 K-depth per sub-step
    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; // whole Q hdim

    // Total LDS bytes this pipeline needs (K/V staging, sized by the policy).
    __host__ __device__ static constexpr ck::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    // Main attention loop for one block-row of Q against all K/V tiles.
    // The element functions are applied element-wise to Q/K/V as the tiles are
    // consumed (identity in the convenience overload below).
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction>
    __host__ __device__ auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const KElementFunction& k_element_func,
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const VElementFunction& v_element_func,
               float scale,
               index_t num_total_loop,
               index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static
               void* smem_ptr) const
    {
        static_assert(
            is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] &&
                          kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
                          kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}],
                      "wrong!");

        // K tile in LDS, placed after the policy's Q LDS region (which this
        // pipeline's policy sizes as 0 bytes — Q lives in registers)
        KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
        auto k_lds = make_tensor_view<AddressSpaceEnum::Lds>(
            k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(Number<kN0>{}, Number<kK0>{}), {0, 0});

        // V tile in LDS (shares the smem_ptr allocation with K; the two gemm
        // stages are separated by block_sync_lds calls below)
        auto v_lds = make_tensor_view<AddressSpaceEnum::Lds>(
            reinterpret_cast<VDataType*>(smem_ptr),
            Policy::template MakeVLdsBlockDescriptor<Problem>());
        auto v_lds_window =
            make_tile_window(v_lds, make_tuple(Number<kN1>{}, Number<kK1>{}), {0, 0});

        // Block GEMMs: gemm_0 computes S = Q*K^T, gemm_1 accumulates O += P*V
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

        // Q DRAM window distributed to match gemm_0's A-operand register
        // layout, so the loaded tile is consumed directly (no LDS round trip)
        auto q_dram_window = make_tile_window(
            q_dram_block_window_tmp.GetBottomTensorView(),
            q_dram_block_window_tmp.GetWindowLengths(),
            q_dram_block_window_tmp.GetWindowOrigin(),
            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());

        auto q = load_tile(q_dram_window); // persistent q register tile

        // S accumulator type: result of gemm_0 on one (kM0, kK0) slice of Q
        auto s_acc = decltype(gemm_0(get_slice_tile(tile_elementwise_in(q_element_func, q),
                                                    Sequence<0, 0>{},
                                                    Sequence<kM0, kK0>{}),
                                     k_lds_window)){};

        // reduction function for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType =
            decltype(tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc));

        using PBlockTileType =
            decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>, s_acc));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));

        using OaccBlockTileType = decltype(gemm_1(
            get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0, kK1>{}),
            v_lds_window));

        // init Oacc = 0, running row-max M = lowest, running row-sum L = 0
        auto o_acc = OaccBlockTileType{};
        auto m     = MLBlockTileType{};
        auto l     = MLBlockTileType{};

        tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
        tile_elementwise_inout([](auto& e) { e = NumericLimits<SMPLComputeDataType>::Lowest(); },
                               m);
        tile_elementwise_inout([](auto& e) { e = 0; }, l);

        auto k_dram_block_window = k_dram_block_window_tmp;

        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(),
                             v_dram_block_window_tmp.GetWindowLengths(),
                             v_dram_block_window_tmp.GetWindowOrigin(),
                             Policy::template MakeVDramTileDistribution<Problem>());

        // apply the Q element-wise op once, outside the main loop
        auto q_tile = tile_elementwise_in(q_element_func, q);

        index_t i_total_loops = 0;
        do
        {
            // STAGE 1, QK gemm: S = Q*K^T with K streamed through LDS,
            // overlapping the next global read with the current LDS compute
            auto k_dram_window = make_tile_window(
                k_dram_block_window.GetBottomTensorView(),
                k_dram_block_window.GetWindowLengths(),
                k_dram_block_window.GetWindowOrigin(),
                Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                       // load

            auto k_block_tile = load_tile(k_dram_window);
            {
                move_tile_window(k_dram_window, {0, kK0});

                tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C

                store_tile(k_lds_window,
                           tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0
                k_block_tile = load_tile(k_dram_window);                       // global read 1
            }

            // index_t i_k0_loops = num_sub_loop_qk - 2;
            // number of kK0-wide sub-steps covering the whole Q hdim.
            // NOTE(review): the tail below indexes (k0_loops - 2) — this
            // presumably requires k0_loops >= 2; confirm for small hdim.
            constexpr index_t k0_loops = kK0BlockLength / kK0;

            if constexpr(k0_loops > 2)
            {
                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                    block_sync_lds();
                    gemm_0(s_acc,
                           get_slice_tile(q_tile,
                                          Sequence<0, i_k0 * kK0>{},
                                          Sequence<kM0, (i_k0 + 1) * kK0>{}),
                           k_lds_window);
                    block_sync_lds();

                    move_tile_window(k_dram_window, {0, kK0});
                    store_tile(
                        k_lds_window,
                        tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
                    k_block_tile = load_tile(k_dram_window);                // global read i + 2
                });
            }

            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile

            { // tail: last two sub-steps, no more K global reads to overlap
                block_sync_lds();
                gemm_0(s_acc,
                       get_slice_tile(q_tile,
                                      Sequence<0, (k0_loops - 2) * kK0>{},
                                      Sequence<kM0, (k0_loops - 1) * kK0>{}),
                       k_lds_window);
                block_sync_lds();

                store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
                block_sync_lds();

                gemm_0(s_acc,
                       get_slice_tile(q_tile,
                                      Sequence<0, (k0_loops - 1) * kK0>{},
                                      Sequence<kM0, k0_loops * kK0>{}),
                       k_lds_window);
            }

            // STAGE 2, scale softmax: online-softmax update of m (row max)
            // and l (row sum)
            tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);

            const auto s =
                tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc); // S{j}

            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                Sequence<1>{},
                f_max,
                NumericLimits<SMPLComputeDataType>::Lowest()); // m_local = rowmax(S{j})
            block_tile_reduce_sync(m_local, f_max);

            const auto m_old = m; // m{j-1}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}

            auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                s.GetTileDistribution()); // Pcompute{j}

            // P = exp(S - rowmax(S)) element-wise
            constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans();
            sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    p_compute(i_j_idx)     = math::exp(s[i_j_idx] - m[i_idx]);
                });
            });

            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
            block_tile_reduce_sync(rowsum_p, f_sum);

            // l{j}, Oacc{j}: rescale running sum and output by exp(m_old - m)
            constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
            sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                const auto tmp       = math::exp(m_old[i_idx] - m[i_idx]);
                l(i_idx)             = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // FIXME: this use different equation from FA v2 paper,
                    // but produce correct result.
                    // Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });

            block_sync_lds();
            store_tile(v_lds_window,
                       tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
            move_tile_window(v_dram_window, {0, kK1});

            const auto p =
                tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);

            // STAGE 3, KV gemm: O += P*V, V streamed through LDS in kK1 steps
            constexpr index_t k1_loops = kN0 / kK1;
            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    const auto v = load_tile(v_dram_window); // load next v
                    block_sync_lds();
                    gemm_1(o_acc,
                           get_slice_tile(
                               p, Sequence<0, i_k1 * kK1>{}, Sequence<kM0, (i_k1 + 1) * kK1>{}),
                           v_lds_window);
                    block_sync_lds();
                    store_tile(v_lds_window,
                               tile_elementwise_in(v_element_func, v)); // store next v
                    move_tile_window(v_dram_window, {0, kK1});
                });
            }
            // tail
            {
                block_sync_lds();
                gemm_1(o_acc,
                       get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence<kM0, kN0>{}),
                       v_lds_window);
                block_sync_lds();
            }
            // move K tile window to the next kN0 rows of K
            move_tile_window(k_dram_block_window, {kN0, 0});
            i_total_loops++;
        } while(i_total_loops < num_total_loop);

        // finally, O: divide by the softmax denominator l
        constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();

        sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            const auto tmp = 1 / l[i_idx];

            sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                o_acc(i_j_idx) *= tmp;
            });
        });

        return o_acc;
    }

    // Convenience overload: identity element functions for Q, K and V.
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp>
    __host__ __device__ auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               float scale,
               index_t num_total_loop,
               index_t num_sub_loop_qk,
               void* smem_ptr) const
    {
        return operator()(
            q_dram_block_window_tmp,
            [](const QDataType& x) { return x; },
            k_dram_block_window_tmp,
            [](const KDataType& x) { return x; },
            v_dram_block_window_tmp,
            [](const VDataType& x) { return x; },
            scale,
            num_total_loop,
            num_sub_loop_qk,
            smem_ptr);
    }
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// This pipeline is qkv all located in LDS
struct BlockFmhaPipelineQRKSVSDefaultPolicy
{
// Tile distribution for the persistent Q register block (kM0 x kK0BlockLength).
// Derived from the QK block-gemm's warp-gemm config so the register layout is
// exactly the A-operand layout gemm-0 consumes. The N-warp count goes into the
// first (replication) parameter of the encoding — Q is presumably replicated
// across warps that split the N dimension.
template <typename Problem, typename BlockGemm>
__host__ __device__ static constexpr auto MakeQRegBlockDescriptor()
{
    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; // whole Q hdim

    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WG = remove_cvref_t<decltype(config.template At<0>())>;

    constexpr index_t MWarp = config.template At<1>();
    constexpr index_t NWarp = config.template At<2>();

    constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); // per-warp M repeats
    constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;           // per-warp K repeats

    // Block-level (outer) encoding; the warp-gemm's own A-operand encoding is
    // embedded beneath it below.
    constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding<
        Sequence<NWarp>,
        Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
        Tuple<Sequence<1, 0>>,
        Tuple<Sequence<1, 0>>,
        Sequence<1, 2>,
        Sequence<0, 0>>{};

    constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

    constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);

    return q_block_dstr;
}
// 3d + padding
// LDS descriptor for the K tile consumed by gemm-0 (S = Q*K^T): a 3-D staging
// layout (K/8, N, 8) with a one-row pad (presumably to avoid LDS bank
// conflicts), merged back to the 2-D (N, K) view the tile window expects.
template <typename Problem>
__host__ __device__ static constexpr auto MakeKLdsBlockDescriptor()
{
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    // BUGFIX: the K tile of gemm-0 spans kK0, not kK1. The pipeline creates
    // k_lds_window with lengths (kN0, kK0), so the descriptor must match;
    // the previous kK1 only worked while the configs kept kK0 == kK1.
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

    constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
        make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
        Number<8>{},
        Number<1>{});

    constexpr auto k_lds_block_desc = transform_tensor_descriptor(
        k_lds_block_desc_0,
        make_tuple(make_pass_through_transform(kNPerBlock),
                   make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return k_lds_block_desc;
}
// 3d + padding
// LDS descriptor for the V tile consumed by gemm-1 (O += P*V): a 3-D staging
// layout (K/vec, N, vec) with a one-row pad, merged back into the 2-D (N, K)
// view the tile window expects.
template <typename Problem>
__host__ __device__ static constexpr auto MakeVLdsBlockDescriptor()
{
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
    constexpr index_t kRowPad    = 1; // extra row of padding per N
    constexpr index_t kKVector   = 8; // innermost contiguous elements

    // staging view: (K / kKVector, N, kKVector), padded along N
    constexpr auto staged_desc = make_naive_tensor_descriptor(
        make_tuple(Number<kKPerBlock / kKVector>{}, Number<kNPerBlock>{}, Number<kKVector>{}),
        make_tuple(Number<(kNPerBlock + kRowPad) * kKVector>{}, Number<kKVector>{}, Number<1>{}),
        Number<kKVector>{},
        Number<1>{});

    // collapse the split K dimensions back into a single (N, K) view
    constexpr auto merged_desc = transform_tensor_descriptor(
        staged_desc,
        make_tuple(
            make_pass_through_transform(kNPerBlock),
            make_merge_transform(make_tuple(Number<kKPerBlock / kKVector>{}, Number<kKVector>{}))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return merged_desc;
}
// LDS bytes reserved for Q: zero, because this pipeline keeps Q persistent in
// registers (the pipeline still adds this offset when placing the K tile).
template <typename Problem>
__host__ __device__ static constexpr ck::index_t GetSmemSizeQ()
{
    return 0;
}
// Total LDS bytes for the pipeline. The gemm-0 stage (Q scratch + K tile) and
// the gemm-1 stage (V tile) reuse the same allocation, so the requirement is
// the larger of the two.
template <typename Problem>
__host__ __device__ static constexpr ck::index_t GetSmemSize()
{
    // bytes alive during gemm-0: Q LDS scratch (0 for this policy) + K tile
    constexpr index_t gemm0_bytes =
        GetSmemSizeQ<Problem>() +
        MakeKLdsBlockDescriptor<Problem>().GetElementSpaceSize() *
            sizeof(typename Problem::KDataType);

    // bytes alive during gemm-1: the V tile
    constexpr index_t gemm1_bytes =
        sizeof(typename Problem::VDataType) *
        MakeVLdsBlockDescriptor<Problem>().GetElementSpaceSize();

    // TODO: consider shuffle requirement
    return math::max(gemm0_bytes, gemm1_bytes);
}
// Tile distribution for loading Q straight from DRAM into registers, shaped
// from the QK block-gemm's warp-gemm so the loaded layout is the A-operand
// layout gemm-0 consumes (no LDS round trip for Q).
template <typename Problem, typename BlockGemm>
__host__ __device__ static constexpr auto MakeQDramTileDistribution()
{
    constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WG = remove_cvref_t<decltype(config.template At<0>())>;

    constexpr index_t MWarp = config.template At<1>();

    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; // whole Q hdim

    // K split: K1 = kABKLane (presumably lanes along K inside the warp-gemm),
    // K2 = per-lane elements per warp-gemm step, K0 = repeats to cover hdim
    constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
    constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
    constexpr index_t K0 = kKPerBlock / (K1 * K2);
    // M split: M2 = kAMLane (presumably lanes along M), M1 = warps along M,
    // M0 = per-warp M repeats
    constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
    constexpr index_t M1 = MWarp;
    constexpr index_t M0 = kMPerBlock / (M2 * M1);

    return make_static_tile_distribution(
        StaticTileDistributionEncoding<Sequence<1>,
                                       Tuple<Sequence<M0, M1, M2>, Sequence<K0, K1, K2>>,
                                       Tuple<Sequence<1>, Sequence<2, 1>>,
                                       Tuple<Sequence<1>, Sequence<1, 2>>,
                                       Sequence<2, 1, 2>,
                                       Sequence<0, 0, 2>>{});
}
// Tile distribution for loading one K tile (kN0 x kK0) from DRAM before it is
// written into LDS. K1 is sized so each load moves 16 bytes per element group.
template <typename Problem>
__host__ __device__ static constexpr auto MakeKDramTileDistribution()
{
    using KDataType = remove_cvref_t<typename Problem::KDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;

    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

    constexpr index_t K1 = 16 / sizeof(KDataType); // elements per 16-byte access
    constexpr index_t K0 = kKPerBlock / K1;        // K groups per row
    constexpr index_t N2 = get_warp_size() / K0;   // N rows covered per warp
#if 1 // coalesce reading for each blocks
    // warps tile the N dimension contiguously across the whole block
    constexpr index_t N1 = kBlockSize / get_warp_size();
    constexpr index_t N0 = kNPerBlock / (N2 * N1);

    return make_static_tile_distribution(
        StaticTileDistributionEncoding<Sequence<1>,
                                       Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
                                       Tuple<Sequence<1>, Sequence<1, 2>>,
                                       Tuple<Sequence<1>, Sequence<2, 0>>,
                                       Sequence<1, 2>,
                                       Sequence<0, 1>>{});
#else // coalesce reading for each warps
    // alternative mapping: each warp owns a contiguous chunk of N
    constexpr index_t N0 = kBlockSize / get_warp_size();
    constexpr index_t N1 = kNPerBlock / (N2 * N0);

    return make_static_tile_distribution(
        StaticTileDistributionEncoding<Sequence<1>,
                                       Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
                                       Tuple<Sequence<1>, Sequence<1, 2>>,
                                       Tuple<Sequence<0>, Sequence<2, 0>>,
                                       Sequence<1, 2>,
                                       Sequence<1, 1>>{});
#endif
}
template <typename Problem>
__device__ static constexpr auto MakeVDramTileDistribution()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t K1 = 16 / sizeof(VDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
// Block-level gemm for S = Q * K^T (first gemm of the attention kernel).
// A (Q) stays resident in registers; B (K) is staged through shared memory
// (BlockGemmARegBSmemCRegV1).
template <typename Problem>
__host__ __device__ static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
// NOTE(review): the generic WarpGemmMfmaDispatcher (kept commented out
// below) has no entry for the SwizzleB variant, so the warp gemm is spelled
// out explicitly here instead of going through the dispatcher.
// using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<typename
// Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType,
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<0>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>;
// f16 x f16 -> f32 mfma on a 32x32x8 tile, iterated 2x along K
// (effective warp tile 32x32x16), with transposed-C distribution and
// swizzled B. This hard-codes the Gemm0WarpTile choice — TODO confirm it
// matches BlockFmhaShape::Gemm0WarpTile for all instantiations.
using WarpGemm = warp::WarpGemmImpl<
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
warp::WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
2>>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Block-level gemm for O = P * V (second gemm of the attention kernel).
// A (P) stays resident in registers; B (V) is staged through shared memory.
// Unlike GetQKBlockGemm, the warp gemm here is selected through the
// WarpGemmMfmaDispatcher from the shape's Gemm1WarpTile, with
// TransposeC = true.
template <typename Problem>
__host__ __device__ static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
// using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
...@@ -8,19 +8,27 @@ ...@@ -8,19 +8,27 @@
namespace ck { namespace ck {
namespace tile_program { namespace tile_program {
template <index_t kM0PerTile_, // tile size along q seqlen template <typename BlockTile_, // Sequence<...
index_t kN0PerTile_, // tile size along k seqlen typename Gemm0BlockWarps_,
index_t kK0PerTile_, // tile size along qk gemm unroll typename Gemm0WarpTile_,
index_t kN1PerTile_, // tile size along v head_dim typename Gemm1BlockWarps_,
index_t kK1PerTile_ // tile size along kv gemm unroll typename Gemm1WarpTile_>
>
struct TileFmhaShape struct TileFmhaShape
{ {
static constexpr index_t kM0 = kM0PerTile_; using BlockTile = remove_cvref_t<BlockTile_>;
static constexpr index_t kN0 = kN0PerTile_; using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
static constexpr index_t kK0 = kK0PerTile_; using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
static constexpr index_t kN1 = kN1PerTile_; using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
static constexpr index_t kK1 = kK1PerTile_; using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t kM0 = BlockTile::At(Number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::At(Number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::At(Number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::At(Number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::At(Number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kK0BlockLength =
BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
}; };
} // namespace tile_program } // namespace tile_program
......
...@@ -22,9 +22,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 = ...@@ -22,9 +22,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 =
using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>; WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8, WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
......
...@@ -287,6 +287,90 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution ...@@ -287,6 +287,90 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
} }
}; };
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
// swap A and B
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
using AWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kBNLane>, Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<0, 0>>,
Sequence<2>,
Sequence<1>>;
using BWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
Impl::kABKLane,
2,
Impl::kABKPerLane>,
Sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
Tuple<Sequence<2, 1, 1, 1, 1>>,
Tuple<Sequence<0, 0, 2, 1, 3>>,
Sequence<2>,
Sequence<1>>;
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCNLane>,
Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<1, 0>>,
Sequence<2, 2>,
Sequence<0, 2>>;
// c_vec += a_vec * b_vec
__device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
__device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = Number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
a_vector.template AsType<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
}
};
} // namespace warp } // namespace warp
} // namespace tile_program } // namespace tile_program
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace ck {
namespace tile_program {
namespace warp {
namespace impl {
// Compile-time dispatch table: maps (A type, B type, C type, mfma warp tile
// M/N/K, transposed-C flag) to the concrete warp-gemm implementation type.
// Only the specializations listed below are defined; any other combination
// leaves the primary template incomplete and fails to compile, which makes
// unsupported shapes a build-time error rather than a silent fallback.
template <typename AType,
typename BType,
typename CType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC>
struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16 inputs / fp32 accumulation, 32x32 and 16x16 mfma tiles.
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck::half_t, ck::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
// clang-format on
} // namespace impl
// Public alias: resolves directly to the selected warp-gemm implementation
// type (see impl::WarpGemmMfmaDispatcher above for the supported set).
template <typename AType,
typename BType,
typename CType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC>
using WarpGemmMfmaDispatcher = typename impl::
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
} // namespace warp
} // namespace tile_program
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment