"tests/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "9efcac38af58b7247e205c47efe090b4c6ec7574"
Unverified Commit 3753c4bc authored by carlushuang, committed by GitHub

Fmha pr 2 (#26)

* support hdim=64/128 in the same example code

* support v transpose

* revert gemm.cpp, not intended to be modified

* remove useless code

* fix a bug for swizzle C encoding, no perf change

* optimize LDS encoding

* update LDS layout

* clean up code
parent e71aa1d6
@@ -39,28 +39,128 @@ using ODataType = ck::half_t;
// M0 N0 K0 N1 K1 K0L
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
using FmhaBlockTile = ck::Sequence<128, 128, 32, 128, 32, 128>;
using FmhaBlockWarps = ck::Sequence<4, 1, 1>;
using FmhaWarpTile = ck::Sequence<32, 32, 16>;
using FmhaShape = ck::tile_program::
TileFmhaShape<FmhaBlockTile, FmhaBlockWarps, FmhaWarpTile, FmhaBlockWarps, FmhaWarpTile>;
using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;
using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
256, // BlockSize
FmhaShape>;
using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim
// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen
using FmhaBlockTileHdim64 = ck::Sequence<128, 64, 32, 64, 32, 64>;
using FmhaBlockTileHdim128 = ck::Sequence<128, 128, 32, 128, 32, 128>;
using FmhaBlockWarps = ck::Sequence<4, 1, 1>;
using FmhaWarpTile = ck::Sequence<32, 32, 16>;
using FmhaShapeHDim64 = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim64,
FmhaBlockWarps,
FmhaWarpTile,
FmhaBlockWarps,
FmhaWarpTile,
VLayout>;
using FmhaShapeHDim128 = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim128,
FmhaBlockWarps,
FmhaWarpTile,
FmhaBlockWarps,
FmhaWarpTile,
VLayout>;
using FmhaTilePartitionerHDim64 = FmhaFwdTilePartitioner<FmhaShapeHDim64>;
using FmhaTilePartitionerHDim128 = FmhaFwdTilePartitioner<FmhaShapeHDim128>;
using FmhaPipelineProblemHDim64 =
ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
256, // BlockSize
FmhaShapeHDim64>;
using FmhaPipelineProblemHDim128 =
ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
256, // BlockSize
FmhaShapeHDim128>;
// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;
using FmhaPipelineHDim64 =
ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblemHDim64>;
using FmhaPipelineHDim128 =
ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblemHDim128>;
using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
using FmhaKernelHDim64 = FmhaFwdKernel<FmhaTilePartitionerHDim64, FmhaPipelineHDim64, FmhaEpilogue>;
using FmhaKernelHDim128 =
FmhaFwdKernel<FmhaTilePartitionerHDim128, FmhaPipelineHDim128, FmhaEpilogue>;
template <typename FmhaKernel>
float invoker_fmha_kernel(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck::index_t batch,
ck::index_t nhead,
ck::index_t seqlen_q,
ck::index_t seqlen_k,
ck::index_t hdim_q,
ck::index_t hdim_v,
float scale,
bool i_perm,
bool o_perm)
{
dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v);
constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
using FmhaKernel = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
constexpr bool is_v_rowmajor =
ck::is_same_v<typename FmhaKernel::VLayout, ck::tensor_layout::gemm::RowMajor>;
// batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
auto kargs = FmhaKernel::MakeKargs(
q_ptr,
k_ptr,
v_ptr,
o_ptr,
seqlen_q, // seqlen_q
seqlen_k, // seqlen_k
hdim_q, // hdim_q
hdim_v, // hdim_v
scale,
i_perm ? hdim_q : nhead * hdim_q, // stride_q
i_perm ? hdim_q : nhead * hdim_q, // stride_k
[&]() {
if constexpr(is_v_rowmajor)
return i_perm ? hdim_v : nhead * hdim_v;
else
return i_perm ? seqlen_k : nhead * seqlen_k;
}(), // stride_v
o_perm ? hdim_v : nhead * hdim_v, // stride_o
i_perm ? seqlen_q * hdim_q : hdim_q, // nhead_stride_q
i_perm ? seqlen_k * hdim_q : hdim_q, // nhead_stride_k
[&]() {
if constexpr(is_v_rowmajor)
return i_perm ? seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_k : seqlen_k;
}(), // nhead_stride_v
o_perm ? seqlen_q * hdim_v : hdim_v, // nhead_stride_o
nhead * seqlen_q * hdim_q, // batch_stride_q
nhead * seqlen_k * hdim_q, // batch_stride_k
nhead * hdim_v * seqlen_k, // batch_stride_v
nhead * seqlen_q * hdim_v); // batch_stride_o
float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
FmhaKernel{},
kGridSize,
kBlockSize,
0,
kargs); // BatchStrideO
return ave_time;
}
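// The i_perm/o_perm ternaries above pick between two packed layouts:
// (batch, nhead, seqlen, hdim) when the permute flag is set, and
// (batch, seqlen, nhead, hdim) otherwise. A minimal sketch of the same stride
// arithmetic (helper name is illustrative, not part of this PR; needs <utility>):
static std::pair<ck::index_t, ck::index_t> q_strides_sketch(
    bool i_perm, ck::index_t nhead, ck::index_t seqlen_q, ck::index_t hdim_q)
{
    // (b, h, s, d): rows of one head are hdim apart, heads are seqlen*hdim apart
    // (b, s, h, d): rows are nhead*hdim apart, heads are only hdim apart
    return {i_perm ? hdim_q : nhead * hdim_q,     // stride_q
            i_perm ? seqlen_q * hdim_q : hdim_q}; // nhead_stride_q
}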
int main(int argc, char* argv[])
{
@@ -110,10 +210,14 @@ int main(int argc, char* argv[])
return std::array<ck::index_t, 4>{b, s, h, d};
};
constexpr bool is_v_rowmajor =
ck::is_same_v<typename FmhaKernelHDim64::VLayout, ck::tensor_layout::gemm::RowMajor>;
// host verify
Tensor<QDataType> q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q));
Tensor<KDataType> k_host(get_lengths(i_perm, batch, nhead, seqlen_k, hdim_q));
Tensor<VDataType> v_host(get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k));
Tensor<VDataType> v_host(is_v_rowmajor ? get_lengths(i_perm, batch, nhead, seqlen_k, hdim_v)
: get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k));
Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
#if 0
@@ -135,48 +239,45 @@ int main(int argc, char* argv[])
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v);
constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
std::cout << "batch:" << batch << ", nhead:" << nhead << ", seqlen_q:" << seqlen_q
<< ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
<< ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
<< ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z
<< std::flush << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
// batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
auto kargs = FmhaKernel::MakeKargs(q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqlen_q, // seqlen_q
seqlen_k, // seqlen_k
hdim_q, // hdim_q
hdim_v, // hdim_v
scale,
i_perm ? hdim_q : nhead * hdim_q, // stride_q
i_perm ? hdim_q : nhead * hdim_q, // stride_k
i_perm ? seqlen_k : nhead * seqlen_k, // stride_v
o_perm ? hdim_v : nhead * hdim_v, // stride_o
i_perm ? seqlen_q * hdim_q : hdim_q, // nhead_stride_q
i_perm ? seqlen_k * hdim_q : hdim_q, // nhead_stride_k
i_perm ? hdim_v * seqlen_k : seqlen_k, // nhead_stride_v
o_perm ? seqlen_q * hdim_v : hdim_v, // nhead_stride_o
nhead * seqlen_q * hdim_q, // batch_stride_q
nhead * seqlen_k * hdim_q, // batch_stride_k
nhead * hdim_v * seqlen_k, // batch_stride_v
nhead * seqlen_q * hdim_v); // batch_stride_o
float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
FmhaKernel{},
kGridSize,
kBlockSize,
0,
kargs); // BatchStrideO
<< ", v:" << std::string(FmhaKernelHDim64::VLayout::name) << std::flush << std::endl;
float ave_time = 0;
if(hdim_q == hdim_v && hdim_q == 64)
ave_time = invoker_fmha_kernel<FmhaKernelHDim64>(q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
batch,
nhead,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
scale,
i_perm,
o_perm);
else if(hdim_q == hdim_v && hdim_q == 128)
ave_time = invoker_fmha_kernel<FmhaKernelHDim128>(q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
batch,
nhead,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
scale,
i_perm,
o_perm);
else
{
std::cout << "not support hdim, will not run" << std::endl;
return -1;
}
std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q +
std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k;
@@ -197,7 +298,11 @@ int main(int argc, char* argv[])
{
Tensor<QDataType> q_host_ref({batch * nhead, seqlen_q, hdim_q});
Tensor<KDataType> k_host_ref({batch * nhead, seqlen_k, hdim_q});
Tensor<VDataType> v_host_ref({batch * nhead, hdim_v, seqlen_k});
const auto v_lengths = std::array<ck::index_t, 3>{batch * nhead, hdim_v, seqlen_k};
const auto v_strides = is_v_rowmajor
? std::array<ck::index_t, 3>{hdim_v * seqlen_k, 1, hdim_v}
: std::array<ck::index_t, 3>{hdim_v * seqlen_k, seqlen_k, 1};
Tensor<VDataType> v_host_ref(v_lengths, v_strides);
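// Both stride choices describe the same logical (batch*nhead, hdim_v, seqlen_k)
// view; only the in-memory order of the last two axes differs. Element (g, d, s)
// sits at:
//   row-major V: g*hdim_v*seqlen_k + s*hdim_v + d   -> strides {hdim_v*seqlen_k, 1, hdim_v}
//   col-major V: g*hdim_v*seqlen_k + d*seqlen_k + s -> strides {hdim_v*seqlen_k, seqlen_k, 1}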
Tensor<ODataType> o_host_ref({batch * nhead, seqlen_q, hdim_v});
Tensor<ODataType> o_host_result_ref(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
@@ -212,8 +317,16 @@ int main(int argc, char* argv[])
if(i_perm) k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
else k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
if constexpr (is_v_rowmajor) {
// v_host: (b, h, s, d) -> v_host_ref: (batch*nhead, hdim, seqlen)
if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[3], idx[2]) = self(idx); });
// v_host: (b, s, h, d) -> v_host_ref: (batch*nhead, hdim, seqlen)
else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[3], idx[1]) = self(idx); });
}
else {
if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
}
// reference
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
......
@@ -26,6 +26,8 @@ struct FmhaFwdKernel
using VDataType = ck::remove_cvref_t<typename FmhaPipeline::VDataType>;
using ODataType = ck::remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck::remove_cvref_t<typename FmhaPipeline::VLayout>;
struct Kargs
{
const void* q_ptr;
@@ -126,7 +128,6 @@ struct FmhaFwdKernel
i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o;
// Q/K/V DRAM and DRAM window
// FIXME: assume layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k],
const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
@@ -141,12 +142,32 @@ struct FmhaFwdKernel
Number<32>{},
Number<1>{});
const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<1>{});
const auto v_dram = [&]() {
if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_tmp = make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<1>{});
return transform_tensor_view(
v_dram_tmp,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<1>{});
}
}();
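// The row-major branch above moves no data; it only re-indexes the
// (seqlen_k, hdim_v) view as (hdim_v, seqlen_k). A minimal sketch of the same
// idea, independent of the ck view machinery (illustrative, not the ck API):
//
//   struct TransposedVView
//   {
//       const float* data;  // (seqlen_k, hdim_v) row-major storage
//       std::size_t stride; // stride_v: elements between consecutive seqlen rows
//       // view(d, s) reads V[s][d], exactly what the transform above expresses
//       float operator()(std::size_t d, std::size_t s) const
//       {
//           return data[s * stride + d];
//       }
//   };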
auto q_dram_window = make_tile_window(
q_dram,
......
@@ -17,6 +17,7 @@
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
@@ -36,6 +37,7 @@ struct BlockFmhaPipelineQRKSVS
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether q loads the whole block length (hdim) at once
static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -263,8 +265,20 @@ struct BlockFmhaPipelineQRKSVS
});
block_sync_lds();
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
@@ -282,11 +296,26 @@ struct BlockFmhaPipelineQRKSVS
p, Sequence<0, i_k1 * kK1>{}, Sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_distributed_tensor(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store next v
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
move_tile_window(v_dram_window, {0, kK1});
});
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
i_total_loops++;
// tail
{
block_sync_lds();
@@ -295,10 +324,6 @@ struct BlockFmhaPipelineQRKSVS
v_lds_window);
block_sync_lds();
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
i_total_loops++;
} while(i_total_loops < num_total_loop);
// finally, O
......
@@ -15,6 +15,7 @@
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
namespace ck {
namespace tile_program {
@@ -23,6 +24,27 @@ namespace block {
// In this pipeline, q/k/v are all located in LDS
struct BlockFmhaPipelineQRKSVSDefaultPolicy
{
template <typename Problem>
__host__ __device__ static constexpr auto GetSmemKPackK()
{
// TODO: this is for 3d layout
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
__host__ __device__ static constexpr auto GetSmemKPackV()
{
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
__host__ __device__ static constexpr auto GetTransposedVectorloadV()
{
return 4; // TODO: fix me
}
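// Both pack getters size the K-contiguous chunk to one 16-byte (128-bit)
// per-lane vector, the widest plain load/store; sanity check, assuming the
// usual 2-byte half_t and 4-byte float:
static_assert(16 / sizeof(half_t) == 8, "fp16 -> pack of 8 elements");
static_assert(16 / sizeof(float) == 4, "fp32 -> pack of 4 elements");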
template <typename Problem, typename BlockGemm>
__host__ __device__ static constexpr auto MakeQRegBlockDescriptor()
{
@@ -61,17 +83,18 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
make_tuple(Number<kKPerBlock / kKPack>{}, Number<kNPerBlock>{}, Number<kKPack>{}),
make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number<kKPack>{}, Number<1>{}),
Number<8>{},
Number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -82,25 +105,60 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
template <typename Problem>
__host__ __device__ static constexpr auto MakeVLdsBlockDescriptor()
{
#if 0
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kPad = 1;
constexpr index_t kK1 = 8;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
Number<kK1>{},
make_tuple(Number<kKPerBlock / kKPack>{}, Number<kNPerBlock>{}, Number<kKPack>{}),
make_tuple(Number<(kNPerBlock + kPad) * kKPack>{}, Number<kKPack>{}, Number<1>{}),
Number<kKPack>{},
Number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
make_merge_transform(make_tuple(Number<kKPerBlock / kKPack>{}, Number<kKPack>{}))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return v_lds_block_desc;
#else
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kKPack>{},
Number<kNPerBlock / NPerRow>{},
Number<NPerRow>{},
Number<kKPack>{}),
make_tuple(Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
Number<PixelsPerRow + kKPack>{},
Number<kKPack>{},
Number<1>{}),
Number<kKPack>{},
Number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(Number<kNPerBlock / NPerRow>{}, Number<NPerRow>{})),
make_merge_transform(make_tuple(Number<kKPerBlock / kKPack>{}, Number<kKPack>{}))),
make_tuple(Sequence<1, 2>{}, Sequence<0, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return v_lds_block_desc;
#endif
}
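// Concrete numbers for fp16 V under the 32-bank, 4-byte-per-bank figure above
// (illustrative): PixelsPerRow = 32*4/2 = 64 elements, kKPack = 8, NPerRow = 8,
// so the row pitch is (PixelsPerRow + kKPack) = 72 elements = 144 bytes. 144 is
// not a multiple of the 128-byte bank period, hence consecutive rows start in
// different banks and column-wise reads avoid systematic conflicts.
static_assert((64 + 8) * 2 % 128 != 0, "padded row pitch skews rows across banks");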
template <typename Problem>
@@ -192,25 +250,81 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
__device__ static constexpr auto MakeVDramTileDistribution()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t K1 = 16 / sizeof(VDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetTransposedVectorloadV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1>, Sequence<K0, K1, K2, K3>>,
Tuple<Sequence<2>, Sequence<2, 1, 2>>,
Tuple<Sequence<0>, Sequence<1, 0, 2>>,
Sequence<2, 1>,
Sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = 16 / sizeof(VDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
}
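// Worked factorization of the row-major branch for this PR's hdim=128 tile
// (kBlockSize = 256, kN1 = 128, kK1 = 32), fp16 V, 64-lane waves; the numbers
// are illustrative, derived from the formulas above:
//   N1 = 4 (GetTransposedVectorloadV), N0 = 128/4 = 32,
//   total_pixels = 128*32/256 = 16 elements per thread,
//   K3 = 16/4 = 4, kKPack = 16/2 = 8, K2 = 8/4 = 2,
//   K1 = 64/(2*32) = 1, K0 = 256/64 = 4.
static_assert(4 * 1 * 2 * 4 == 32, "K0*K1*K2*K3 recovers kKPerBlock");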
template <typename Problem>
__host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor()
{
// This descriptor is only used when the V layout is seqlen * hdim (row-major)
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
static_assert(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetTransposedVectorloadV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Tuple<Sequence<N0, N1>, Sequence<K0, K1, K2, K3>>,
Tuple<Sequence<2>, Sequence<2, 1, 2>>,
Tuple<Sequence<0>, Sequence<1, 0, 2>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
Sequence<1, 3>>{});
}
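// Note: this encoding differs from the row-major MakeVDramTileDistribution
// only in the Ys order, (2,1),(3,1) there versus (1,2),(1,3) here: each
// thread's N1 x K3 sub-tile is traversed transposed, which is exactly the
// in-register transpose that shuffle_distributed_tensor performs before the
// LDS store.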
template <typename Problem>
@@ -224,11 +338,6 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
// using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<typename
// Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType,
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<0>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>;
using WarpGemm = warp::WarpGemmImpl<
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
@@ -256,7 +365,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
// using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
......
@@ -4,6 +4,7 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
namespace ck {
namespace tile_program {
@@ -12,7 +13,8 @@ template <typename BlockTile_, // Sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_>
typename Gemm1WarpTile_,
typename VLayout_ = ck::tensor_layout::gemm::RowMajor>
struct TileFmhaShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
@@ -29,6 +31,8 @@ struct TileFmhaShape
static constexpr index_t kK0BlockLength =
BlockTile::At(Number<5>{}); // total length of K0, used for pipelines that need to load Q at
// once (or repeatedly load Q as a whole tile)
using VLayout = remove_cvref_t<VLayout_>; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen
};
} // namespace tile_program
......
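Because VLayout_ defaults to RowMajor, existing TileFmhaShape instantiations compile unchanged; passing the layout explicitly opts into the hdim*seqlen form. A sketch reusing the example's aliases (illustrative):

using ShapeRowMajorV = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim64,
                                                       FmhaBlockWarps,
                                                       FmhaWarpTile,
                                                       FmhaBlockWarps,
                                                       FmhaWarpTile>; // VLayout = RowMajor
using ShapeColMajorV = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim64,
                                                       FmhaBlockWarps,
                                                       FmhaWarpTile,
                                                       FmhaBlockWarps,
                                                       FmhaWarpTile,
                                                       ck::tensor_layout::gemm::ColumnMajor>;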
@@ -328,7 +328,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
using CWarpDstrEncoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<Impl::kCNLane>,
Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
Sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
Tuple<Sequence<2, 1>>,
Tuple<Sequence<1, 0>>,
Sequence<2, 2>,
......