"src/vscode:/vscode.git/clone" did not exist on "beb932c5d111872c5e45387e7b1b2b3dd0524a47"
Unverified Commit 9f36ac7c authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

add fmha fwd pipeline (#17)



* Revert "Extract gemm0 prefetch0 out from loop"

This reverts commit d3b56f39f9fd12edb476b24ae9cf480841d311e4.

* add fmha fwd pipeline

* Extract gemm0 prefetch0 out from loop

* move blockSize to another place; fix a missing header in tile_window_impl_static_distribution.hpp

* remove KArgs from tile modules

---------
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 63bc96e3
...@@ -5,3 +5,4 @@ add_example_executable(example_reduce reduce.cpp)
add_example_executable(example_softmax softmax.cpp)
add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
add_example_executable(example_fmha_fwd fmha_fwd.cpp)
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp"
#include "ck/tile_program/tile/tile_fmha_shape.hpp"
#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "fmha_fwd_kernel.hpp"
#include "fmha_fwd_epilogue.hpp"
using QDataType = ck::half_t;
using KDataType = ck::half_t;
using VDataType = ck::half_t;
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck::half_t;
using FmhaShape =
ck::tile_program::TileFmhaShape<128 /*M0*/, 128 /*N0*/, 32 /*K0*/, 128 /*N1*/, 32 /*K1*/>;
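// With this shape each workgroup produces one 128x128 tile of O, looping over
// seqlen_k in steps of kN0 = 128 and unrolling the QK gemm over hdim_q in
// steps of kK0 = 32 (see num_total_loop / num_sub_loop_qk in the pipeline)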
using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
256, // BlockSize
FmhaShape>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
using FmhaKernel = FmhaFwdKernel<FmhaPipeline, FmhaEpilogue>;
int main(int argc, char* argv[])
{
ck::index_t Batch = 16; // batch * nheads
ck::index_t M0 = 3328; // seqlen_q
ck::index_t N0 = 4096; // seqlen_k
ck::index_t K0 = 128; // hdim_q
ck::index_t N1 = 128; // hdim_v
if(argc == 6)
{
Batch = std::stoi(argv[1]);
M0 = std::stoi(argv[2]);
N0 = std::stoi(argv[3]);
K0 = std::stoi(argv[4]);
N1 = std::stoi(argv[5]);
}
std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};
std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};
std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};
std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};
std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};
std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};
// host verify
Tensor<QDataType> q_host(q_lengths, q_strides);
Tensor<KDataType> k_host(k_lengths, k_strides);
Tensor<VDataType> v_host(v_lengths, v_strides);
Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
Tensor<PDataType> p_host_ref(p_lengths, p_strides);
Tensor<ODataType> o_host_ref(o_lengths, o_strides);
Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
#endif
// reference
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
p_host_ref);
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host, o_host_ref);
DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());
q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
dim3 kGridSize = FmhaKernel::GridSize(Batch, M0, N1);
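// with the default sizes above: 16 * (3328 / 128) * (128 / 128) = 416 blocks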
constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
std::cout << "batch:" << Batch << ", seqlen_q:" << M0 << ", seqlen_k:" << N0
<< ", hdim_q:" << K0 << ", hdim_v:" << N1 << ", grid_size " << kGridSize.x
<< std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
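// e.g. 256 threads / 64-lane wavefronts = 4 wavefronts per block,
// so kBlockPerCu = 8 / 4 = 2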
auto kargs = FmhaKernel::MakeKargs(q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
M0, // seqlen_q
N0, // seqlen_k
K0, // hdim_q
N1, // hdim_v
K0, // stride_q
K0, // stride_k
N0, // stride_v
N1, // stride_o
M0 * K0, // batch_stride_q
N0 * K0, // batch_stride_k
N1 * N0, // batch_stride_v
M0 * N1); // batch_stride_o
float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
FmhaKernel{},
kGridSize,
kBlockSize,
0,
kargs);
o_buf.FromDevice(o_host_dev.mData.data());
std::size_t flop =
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
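// i.e. 2 * B * M0 * N0 * K0 flops for S = Q * K^T plus 2 * B * M0 * N1 * N0
// for O = P * V; softmax flops are not counted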
std::size_t num_btype =
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !ck::utils::check_err(o_host_dev, o_host_ref);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
template <typename OaccDataType_, typename ODataType_>
struct FmhaFwdEpilogueProblem
{
using OaccDataType = ck::remove_cvref_t<OaccDataType_>;
using ODataType = ck::remove_cvref_t<ODataType_>;
};
template <typename Problem_, typename Policy_ = void>
struct FmhaFwdEpilogue
{
using Problem = ck::remove_cvref_t<Problem_>;
using OaccDataType = ck::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck::remove_cvref_t<typename Problem::ODataType>;
__host__ __device__ static constexpr ck::index_t GetSmemSize() { return 0; }
template <typename ODramWindowTmp, typename OAccTile>
__device__ auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
{
using namespace ck;
using namespace ck::tile_program;
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc_tile);
store_tile(o_dram_window_tmp, o);
}
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
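// Each workgroup computes one (batch, seqlen_q tile, hdim_v tile) block of O;
// the flat block index from GridSize() is decomposed back into these three
// coordinates inside operator().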
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
using FmhaPipeline = ck::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize;
using QDataType = ck::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck::remove_cvref_t<typename FmhaPipeline::VDataType>;
using ODataType = ck::remove_cvref_t<typename FmhaPipeline::ODataType>;
struct Kargs
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
ck::index_t seqlen_q;
ck::index_t seqlen_k;
ck::index_t hdim_q;
ck::index_t hdim_v;
ck::index_t stride_q;
ck::index_t stride_k;
ck::index_t stride_v;
ck::index_t stride_o;
ck::index_t batch_stride_q;
ck::index_t batch_stride_k;
ck::index_t batch_stride_v;
ck::index_t batch_stride_o;
};
__host__ static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck::index_t seqlen_q,
ck::index_t seqlen_k,
ck::index_t hdim_q,
ck::index_t hdim_v,
ck::index_t stride_q,
ck::index_t stride_k,
ck::index_t stride_v,
ck::index_t stride_o,
ck::index_t batch_stride_q,
ck::index_t batch_stride_k,
ck::index_t batch_stride_v,
ck::index_t batch_stride_o)
{
return Kargs{q_ptr,
k_ptr,
v_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
stride_q,
stride_k,
stride_v,
stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o};
}
__host__ static constexpr auto
GridSize(ck::index_t batch_size_, ck::index_t seqlen_q_, ck::index_t hdim_v_)
{
return dim3(batch_size_ * (seqlen_q_ / FmhaPipeline::kM0) * (hdim_v_ / FmhaPipeline::kN1));
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
__host__ __device__ static constexpr ck::index_t GetSmemSize()
{
return ck::math::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
__device__ void operator()(Kargs kargs) const
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const index_t num_tile_m0 = kargs.seqlen_q / FmhaPipeline::kM0;
const index_t num_tile_n1 = kargs.hdim_v / FmhaPipeline::kN1;
const index_t id_block = ck::get_block_id();
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck::make_tuple(quotient, modulus);
};
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
const index_t i_batch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(id_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(id_tile_n * FmhaPipeline::kN1);
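// readfirstlane marks these indices as wavefront-uniform so they can live in
// scalar registers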
// for simplicity, the batch stride is applied by simply offsetting the base pointers
const QDataType* q_ptr =
reinterpret_cast<const QDataType*>(kargs.q_ptr) + i_batch * kargs.batch_stride_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) + i_batch * kargs.batch_stride_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) + i_batch * kargs.batch_stride_v;
ODataType* o_ptr =
reinterpret_cast<ODataType*>(kargs.o_ptr) + i_batch * kargs.batch_stride_o;
// Q/K/V DRAM and DRAM window
// FIXME: assumes layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k]
const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
Number<32>{},
Number<1>{});
const auto k_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
Number<32>{},
Number<1>{});
const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<1>{});
auto q_dram_window =
make_tile_window(q_dram,
make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{}),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(Number<FmhaPipeline::kN1>{}, Number<FmhaPipeline::kK1>{}),
{i_n1, 0});
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
kargs.seqlen_k / FmhaPipeline::kN0,
kargs.hdim_q / FmhaPipeline::kK0,
smem_ptr);
// O DRAM and O DRAM window
auto o_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
Number<32>{},
Number<1>{});
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace tile_program {
namespace block {
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
index_t kBlockSize_,
typename BlockFmhaShape_>
struct BlockFmhaPipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
namespace ck {
namespace tile_program {
namespace block {
// In this pipeline the Q, K and V tiles are all staged through LDS
template <typename Problem, typename Policy = BlockFmhaPipelineQKVSDefaultPolicy>
struct BlockFmhaPipelineQKVS
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
__host__ __device__ static constexpr ck::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction>
__host__ __device__ auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
index_t num_total_loop,
index_t num_sub_loop_qk,
void* smem_ptr) const
{
static_assert(
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kK0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] &&
kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}],
"wrong!");
// Q tile in LDS
auto q_lds = make_tensor_view<AddressSpaceEnum::Lds>(
reinterpret_cast<QDataType*>(smem_ptr),
Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(Number<kM0>{}, Number<kK0>{}), {0, 0});
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
auto k_lds = make_tensor_view<AddressSpaceEnum::Lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(Number<kN0>{}, Number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<AddressSpaceEnum::Lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window =
make_tile_window(v_lds, make_tuple(Number<kN1>{}, Number<kK1>{}), {0, 0});
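// Note: the V tile reuses the LDS that held the Q/K tiles during the first
// gemm (GetSmemSize() takes the max of the two stages), so the stages below
// are separated by block_sync_lds() calls.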
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto s_acc = decltype(gemm_0(q_lds_window, k_lds_window)){};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType =
decltype(tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc));
using PBlockTileType =
decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>, s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1(
get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0, kK1>{}),
v_lds_window));
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
tile_elementwise_inout([](auto& e) { e = NumericLimits<SMPLComputeDataType>::Lowest(); },
m);
tile_elementwise_inout([](auto& e) { e = 0; }, l);
auto k_dram_block_window = k_dram_block_window_tmp;
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(),
v_dram_block_window_tmp.GetWindowLengths(),
v_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeVDramTileDistribution<Problem>());
index_t i_total_loops = 0;
do
{
// STAGE 1, QK gemm
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp.GetBottomTensorView(),
q_dram_block_window_tmp.GetWindowLengths(),
q_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto k_dram_window = make_tile_window(
k_dram_block_window.GetBottomTensorView(),
k_dram_block_window.GetWindowLengths(),
k_dram_block_window.GetWindowOrigin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto q_block_tile = load_tile(q_dram_window); // prefetch, global read 0
auto k_block_tile = load_tile(k_dram_window);
{
move_tile_window(q_dram_window, {0, kK0}); // move to 1
move_tile_window(k_dram_window, {0, kK0});
tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C
store_tile(q_lds_window,
tile_elementwise_in(q_element_func, q_block_tile)); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
store_tile(k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0
k_block_tile = load_tile(k_dram_window); // global read 1
}
index_t i_k0_loops = num_sub_loop_qk - 2;
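// the hot loop below runs num_sub_loop_qk - 2 times and the tail issues the
// last two gemms: e.g. hdim_q = 128 with kK0 = 32 gives num_sub_loop_qk = 4,
// i.e. two iterations here plus the tail. The do/while form appears to
// assume num_sub_loop_qk >= 3.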
do
{
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM i
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0}); // move to i + 2
move_tile_window(k_dram_window, {0, kK0});
store_tile(q_lds_window,
tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
store_tile(k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
i_k0_loops--;
} while(i_k0_loops > 0);
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 2
block_sync_lds();
store_tile(
q_lds_window,
tile_elementwise_in(q_element_func, q_block_tile)); // LDS write num_loop - 1
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 1
}
// STAGE 2, scale softmax
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
Sequence<1>{},
f_max,
NumericLimits<SMPLComputeDataType>::Lowest()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max);
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.GetTileDistribution()); // Pcompute{j}
constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans();
sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]);
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = math::exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
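// The two sweeps above implement the usual online-softmax recurrence:
//   m_j  = max(m_{j-1}, rowmax(S_j))
//   P_j  = exp(S_j - m_j)
//   l_j  = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
//   Oacc = exp(m_{j-1} - m_j) * Oacc      (P_j * V_j is added in stage 3)
// with the final 1/l scaling deferred to the end of the outer loop, matching
// the FlashAttention-2 formulation.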
block_sync_lds();
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
move_tile_window(v_dram_window, {0, kK1});
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// STAGE 3, KV gemm
constexpr index_t k1_loops = kN0 / kK1;
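// e.g. kN0 = 128, kK1 = 32 gives k1_loops = 4: three P slices are consumed
// below while the next V tile is prefetched, and the last slice is handled
// in the tail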
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, Sequence<0, i_k1 * kK1>{}, Sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
move_tile_window(v_dram_window, {0, kK1});
});
}
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
// move K tile window to the next kN0 chunk along seqlen_k
move_tile_window(k_dram_block_window, {kN0, 0});
i_total_loops++;
} while(i_total_loops < num_total_loop);
// finally, normalize O by the softmax denominator l
constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = 1 / l[i_idx];
sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp>
__host__ __device__ auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
index_t num_total_loop,
index_t num_sub_loop_qk,
void* smem_ptr) const
{
return operator()(
q_dram_block_window_tmp,
[](const QDataType& x) { return x; },
k_dram_block_window_tmp,
[](const KDataType& x) { return x; },
v_dram_block_window_tmp,
[](const VDataType& x) { return x; },
num_total_loop,
num_sub_loop_qk,
smem_ptr);
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for the pipeline in which Q, K and V are all staged through LDS
struct BlockFmhaPipelineQKVSDefaultPolicy
{
// 3d + padding
template <typename Problem>
__host__ __device__ static constexpr auto MakeQLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
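// the Q tile is stored as kK0/8 slabs of (kMPerBlock x 8) vectors; each slab
// is allocated (kMPerBlock + 1) * 8 elements, one 8-element vector more than
// it needs, presumably to avoid LDS bank conflicts (the "3d + padding" above)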
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / 8>{}, Number<kMPerBlock>{}, Number<8>{}),
make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
Number<8>{},
Number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return q_lds_block_desc;
}
// 3d + padding
template <typename Problem>
__host__ __device__ static constexpr auto MakeKLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
Number<8>{},
Number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return k_lds_block_desc;
}
// 3d + padding
template <typename Problem>
__host__ __device__ static constexpr auto MakeVLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kPad = 1;
constexpr index_t kK1 = 8;
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
Number<kK1>{},
Number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return v_lds_block_desc;
}
template <typename Problem>
__host__ __device__ static constexpr ck::index_t GetSmemSizeQ()
{
constexpr index_t lds_alignment = 16; // optional
constexpr index_t q_smem_size =
ck::math::integer_divide_ceil(
sizeof(typename Problem::QDataType) *
MakeQLdsBlockDescriptor<Problem>().GetElementSpaceSize(),
lds_alignment) *
lds_alignment;
return q_smem_size;
}
template <typename Problem>
__host__ __device__ static constexpr ck::index_t GetSmemSize()
{
constexpr index_t smem_size_gemm_0 =
GetSmemSizeQ<Problem>() + sizeof(typename Problem::KDataType) *
MakeKLdsBlockDescriptor<Problem>().GetElementSpaceSize();
constexpr index_t smem_size_gemm_1 =
MakeVLdsBlockDescriptor<Problem>().GetElementSpaceSize() *
sizeof(typename Problem::VDataType);
// TODO: consider shuffle requirement
return math::max(smem_size_gemm_0, smem_size_gemm_1);
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeQDramTileDistribution()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesced reads at block granularity
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
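// e.g. kBlockSize = 256, kM0 = 128, kK0 = 32 with half_t: K1 = 8 (dwordx4
// loads), K0 = 4, M2 = 16, M1 = 4, M0 = 2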
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<M0, M1, M2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
#else // coalesced reads at warp granularity
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = kMPerBlock / (M2 * M0);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<M0, M1, M2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<0>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<1, 1>>{});
#endif
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeKDramTileDistribution()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = 16 / sizeof(KDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesced reads at block granularity
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
#else // coalesced reads at warp granularity
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N1 = kNPerBlock / (N2 * N0);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<0>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<1, 1>>{});
#endif
}
template <typename Problem>
__device__ static constexpr auto MakeVDramTileDistribution()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t K1 = 16 / sizeof(VDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
template <typename Problem>
__host__ __device__ static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmASmemBSmemCRegV1Problem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
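// For the second gemm, the A operand (the softmax output P) is already
// register-distributed, so an A-in-register / B-in-LDS block gemm is used,
// in contrast to the A/B-both-in-LDS gemm above.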
template <typename Problem>
__host__ __device__ static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmARegBSmemCRegV1Problem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace tile_program {
template <index_t kM0PerTile_, // tile size along q seqlen
index_t kN0PerTile_, // tile size along k seqlen
index_t kK0PerTile_, // tile size along qk gemm unroll
index_t kN1PerTile_, // tile size along v head_dim
index_t kK1PerTile_ // tile size along kv gemm unroll
>
struct TileFmhaShape
{
static constexpr index_t kM0 = kM0PerTile_;
static constexpr index_t kN0 = kN0PerTile_;
static constexpr index_t kK0 = kK0PerTile_;
static constexpr index_t kN1 = kN1PerTile_;
static constexpr index_t kK1 = kK1PerTile_;
};
} // namespace tile_program
} // namespace ck
...@@ -9,6 +9,7 @@
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
...