Unverified commit 3753c4bc authored by carlushuang, committed by GitHub

Fmha pr 2 (#26)

* support hdim=64/128 in the same example code

* support V transpose

* revert gemm.cpp, no intent to modify it

* remove useless code

* fix a bug in the swizzle-C encoding, no perf change

* optimize the LDS encoding

* update the LDS layout

* clean up code
parent e71aa1d6
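
The central change in this PR is letting each (batch, head) slice of the V tensor be
consumed in either of two storage orders: the new RowMajor path, stored as
seqlen * hdim, and the original ColumnMajor path, stored as hdim * seqlen. A minimal
standalone sketch (plain C++, not CK code; all names here are hypothetical) of how
the same logical V element maps to memory under the two layouts:

#include <cassert>
#include <cstddef>

// RowMajor V   : stored [seqlen_k, hdim_v], element (s, d) at offset s * hdim_v + d
// ColumnMajor V: stored [hdim_v, seqlen_k], element (s, d) at offset d * seqlen_k + s
std::size_t v_offset(bool row_major, std::size_t s, std::size_t d,
                     std::size_t seqlen_k, std::size_t hdim_v)
{
    return row_major ? s * hdim_v + d : d * seqlen_k + s;
}

int main()
{
    constexpr std::size_t seqlen_k = 4, hdim_v = 8;
    assert(v_offset(true, 1, 2, seqlen_k, hdim_v) == 10); // row-major
    assert(v_offset(false, 1, 2, seqlen_k, hdim_v) == 9); // column-major
    return 0;
}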
@@ -39,28 +39,128 @@ using ODataType = ck::half_t;
 // M0 N0 K0 N1 K1 K0L
 // using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
 // using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
-using FmhaBlockTile  = ck::Sequence<128, 128, 32, 128, 32, 128>;
-using FmhaBlockWarps = ck::Sequence<4, 1, 1>;
-using FmhaWarpTile   = ck::Sequence<32, 32, 16>;
-using FmhaShape      = ck::tile_program::
-    TileFmhaShape<FmhaBlockTile, FmhaBlockWarps, FmhaWarpTile, FmhaBlockWarps, FmhaWarpTile>;
-
-using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;
-using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
-                                                                              KDataType,
-                                                                              VDataType,
-                                                                              SaccDataType,
-                                                                              SMPLComputeDataType,
-                                                                              PDataType,
-                                                                              OaccDataType,
-                                                                              ODataType,
-                                                                              256, // BlockSize
-                                                                              FmhaShape>;
+using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim
+// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen
+
+using FmhaBlockTileHdim64  = ck::Sequence<128, 64, 32, 64, 32, 64>;
+using FmhaBlockTileHdim128 = ck::Sequence<128, 128, 32, 128, 32, 128>;
+using FmhaBlockWarps       = ck::Sequence<4, 1, 1>;
+using FmhaWarpTile         = ck::Sequence<32, 32, 16>;
+using FmhaShapeHDim64      = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim64,
+                                                             FmhaBlockWarps,
+                                                             FmhaWarpTile,
+                                                             FmhaBlockWarps,
+                                                             FmhaWarpTile,
+                                                             VLayout>;
+using FmhaShapeHDim128     = ck::tile_program::TileFmhaShape<FmhaBlockTileHdim128,
+                                                             FmhaBlockWarps,
+                                                             FmhaWarpTile,
+                                                             FmhaBlockWarps,
+                                                             FmhaWarpTile,
+                                                             VLayout>;
+
+using FmhaTilePartitionerHDim64  = FmhaFwdTilePartitioner<FmhaShapeHDim64>;
+using FmhaTilePartitionerHDim128 = FmhaFwdTilePartitioner<FmhaShapeHDim128>;
+using FmhaPipelineProblemHDim64 =
+    ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
+                                                      KDataType,
+                                                      VDataType,
+                                                      SaccDataType,
+                                                      SMPLComputeDataType,
+                                                      PDataType,
+                                                      OaccDataType,
+                                                      ODataType,
+                                                      256, // BlockSize
+                                                      FmhaShapeHDim64>;
+using FmhaPipelineProblemHDim128 =
+    ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
+                                                      KDataType,
+                                                      VDataType,
+                                                      SaccDataType,
+                                                      SMPLComputeDataType,
+                                                      PDataType,
+                                                      OaccDataType,
+                                                      ODataType,
+                                                      256, // BlockSize
+                                                      FmhaShapeHDim128>;
 
 // using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
-using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;
+using FmhaPipelineHDim64 =
+    ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblemHDim64>;
+using FmhaPipelineHDim128 =
+    ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblemHDim128>;
 
-using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
-using FmhaKernel   = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
+using FmhaEpilogue     = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
+using FmhaKernelHDim64 = FmhaFwdKernel<FmhaTilePartitionerHDim64, FmhaPipelineHDim64, FmhaEpilogue>;
+using FmhaKernelHDim128 =
+    FmhaFwdKernel<FmhaTilePartitionerHDim128, FmhaPipelineHDim128, FmhaEpilogue>;
+
+template <typename FmhaKernel>
+float invoker_fmha_kernel(const void* q_ptr,
+                          const void* k_ptr,
+                          const void* v_ptr,
+                          void* o_ptr,
+                          ck::index_t batch,
+                          ck::index_t nhead,
+                          ck::index_t seqlen_q,
+                          ck::index_t seqlen_k,
+                          ck::index_t hdim_q,
+                          ck::index_t hdim_v,
+                          float scale,
+                          bool i_perm,
+                          bool o_perm)
+{
+    dim3 kGridSize            = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v);
+    constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
+
+    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
+    constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
+    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
+
+    constexpr bool is_v_rowmajor =
+        ck::is_same_v<typename FmhaKernel::VLayout, ck::tensor_layout::gemm::RowMajor>;
+
+    // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
+    auto kargs = FmhaKernel::MakeKargs(
+        q_ptr,
+        k_ptr,
+        v_ptr,
+        o_ptr,
+        seqlen_q, // seqlen_q
+        seqlen_k, // seqlen_k
+        hdim_q,   // hdim_q
+        hdim_v,   // hdim_v
+        scale,
+        i_perm ? hdim_q : nhead * hdim_q, // stride_q
+        i_perm ? hdim_q : nhead * hdim_q, // stride_k
+        [&]() {
+            if constexpr(is_v_rowmajor)
+                return i_perm ? hdim_v : nhead * hdim_v;
+            else
+                return i_perm ? seqlen_k : nhead * seqlen_k;
+        }(),                                 // stride_v
+        o_perm ? hdim_v : nhead * hdim_v,    // stride_o
+        i_perm ? seqlen_q * hdim_q : hdim_q, // nhead_stride_q
+        i_perm ? seqlen_k * hdim_q : hdim_q, // nhead_stride_k
+        [&]() {
+            if constexpr(is_v_rowmajor)
+                return i_perm ? seqlen_k * hdim_v : hdim_v;
+            else
+                return i_perm ? hdim_v * seqlen_k : seqlen_k;
+        }(),                                 // nhead_stride_v
+        o_perm ? seqlen_q * hdim_v : hdim_v, // nhead_stride_o
+        nhead * seqlen_q * hdim_q,           // batch_stride_q
+        nhead * seqlen_k * hdim_q,           // batch_stride_k
+        nhead * hdim_v * seqlen_k,           // batch_stride_v
+        nhead * seqlen_q * hdim_v);          // batch_stride_o
+
+    float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
+                                                              FmhaKernel{},
+                                                              kGridSize,
+                                                              kBlockSize,
+                                                              0,
+                                                              kargs);
+    return ave_time;
+}
 
 int main(int argc, char* argv[])
 {
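
The invoker above derives every stride from two layout flags: i_perm selects
(batch, nhead, seqlen, hdim) versus (batch, seqlen, nhead, hdim) for the inputs, and
the kernel's VLayout additionally changes stride_v/nhead_stride_v. A standalone
sketch (hypothetical names, not the CK API) of the same arithmetic for Q, so the
ternaries are easier to audit:

#include <cassert>
#include <cstdint>

struct Strides { std::int64_t row, nhead, batch; };

// i_perm == true : (batch, nhead, seqlen, hdim); i_perm == false : (batch, seqlen, nhead, hdim)
Strides q_strides(bool i_perm, std::int64_t nhead, std::int64_t seqlen_q, std::int64_t hdim_q)
{
    return {i_perm ? hdim_q : nhead * hdim_q,    // step between consecutive seqlen rows
            i_perm ? seqlen_q * hdim_q : hdim_q, // step between heads
            nhead * seqlen_q * hdim_q};          // step between batches (layout-independent)
}

int main()
{
    auto s = q_strides(true, 8, 128, 64);
    assert(s.row == 64 && s.nhead == 128 * 64 && s.batch == 8 * 128 * 64);
    return 0;
}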
@@ -110,10 +210,14 @@ int main(int argc, char* argv[])
         return std::array<ck::index_t, 4>{b, s, h, d};
     };
 
+    constexpr bool is_v_rowmajor =
+        ck::is_same_v<typename FmhaKernelHDim64::VLayout, ck::tensor_layout::gemm::RowMajor>;
+
     // host verify
     Tensor<QDataType> q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q));
     Tensor<KDataType> k_host(get_lengths(i_perm, batch, nhead, seqlen_k, hdim_q));
-    Tensor<VDataType> v_host(get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k));
+    Tensor<VDataType> v_host(is_v_rowmajor ? get_lengths(i_perm, batch, nhead, seqlen_k, hdim_v)
+                                           : get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k));
    Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
 
 #if 0
@@ -135,48 +239,45 @@ int main(int argc, char* argv[])
     k_buf.ToDevice(k_host.mData.data());
     v_buf.ToDevice(v_host.mData.data());
 
-    dim3 kGridSize            = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v);
-    constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
-
     std::cout << "batch:" << batch << ", nhead:" << nhead << ", seqlen_q:" << seqlen_q
               << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
               << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
-              << ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z
-              << std::flush << std::endl;
-
-    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
-    constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
-    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
-
-    // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
-    auto kargs = FmhaKernel::MakeKargs(q_buf.GetDeviceBuffer(),
-                                       k_buf.GetDeviceBuffer(),
-                                       v_buf.GetDeviceBuffer(),
-                                       o_buf.GetDeviceBuffer(),
-                                       seqlen_q, // seqlen_q
-                                       seqlen_k, // seqlen_k
-                                       hdim_q,   // hdim_q
-                                       hdim_v,   // hdim_v
-                                       scale,
-                                       i_perm ? hdim_q : nhead * hdim_q,      // stride_q
-                                       i_perm ? hdim_q : nhead * hdim_q,      // stride_k
-                                       i_perm ? seqlen_k : nhead * seqlen_k,  // stride_v
-                                       o_perm ? hdim_v : nhead * hdim_v,      // stride_o
-                                       i_perm ? seqlen_q * hdim_q : hdim_q,   // nhead_stride_q
-                                       i_perm ? seqlen_k * hdim_q : hdim_q,   // nhead_stride_k
-                                       i_perm ? hdim_v * seqlen_k : seqlen_k, // nhead_stride_v
-                                       o_perm ? seqlen_q * hdim_v : hdim_v,   // nhead_stride_o
-                                       nhead * seqlen_q * hdim_q,             // batch_stride_q
-                                       nhead * seqlen_k * hdim_q,             // batch_stride_k
-                                       nhead * hdim_v * seqlen_k,             // batch_stride_v
-                                       nhead * seqlen_q * hdim_v);            // batch_stride_o
-
-    float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
-                                                              FmhaKernel{},
-                                                              kGridSize,
-                                                              kBlockSize,
-                                                              0,
-                                                              kargs);
+              << ", v:" << std::string(FmhaKernelHDim64::VLayout::name) << std::flush << std::endl;
+
+    float ave_time = 0;
+
+    if(hdim_q == hdim_v && hdim_q == 64)
+        ave_time = invoker_fmha_kernel<FmhaKernelHDim64>(q_buf.GetDeviceBuffer(),
+                                                         k_buf.GetDeviceBuffer(),
+                                                         v_buf.GetDeviceBuffer(),
+                                                         o_buf.GetDeviceBuffer(),
+                                                         batch,
+                                                         nhead,
+                                                         seqlen_q,
+                                                         seqlen_k,
+                                                         hdim_q,
+                                                         hdim_v,
+                                                         scale,
+                                                         i_perm,
+                                                         o_perm);
+    else if(hdim_q == hdim_v && hdim_q == 128)
+        ave_time = invoker_fmha_kernel<FmhaKernelHDim128>(q_buf.GetDeviceBuffer(),
+                                                          k_buf.GetDeviceBuffer(),
+                                                          v_buf.GetDeviceBuffer(),
+                                                          o_buf.GetDeviceBuffer(),
+                                                          batch,
+                                                          nhead,
+                                                          seqlen_q,
+                                                          seqlen_k,
+                                                          hdim_q,
+                                                          hdim_v,
+                                                          scale,
+                                                          i_perm,
+                                                          o_perm);
+    else
+    {
+        std::cout << "unsupported hdim, will not run" << std::endl;
+        return -1;
+    }
 
     std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q +
                        std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k;
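
The FLOP formula above counts the two batched GEMMs, S = Q * K^T and O = P * V, at
2 * M * N * K each. A standalone check (standard C++) with the numbers written out:

#include <cstdint>

constexpr std::uint64_t fmha_flops(std::uint64_t b, std::uint64_t h, std::uint64_t sq,
                                   std::uint64_t sk, std::uint64_t dq, std::uint64_t dv)
{
    return 2 * b * h * sq * sk * dq   // gemm0: S = Q * K^T
         + 2 * b * h * sq * dv * sk;  // gemm1: O = P * V
}

// e.g. batch=2, nhead=8, seqlen=4096, hdim=128: both gemms contribute equally
static_assert(fmha_flops(2, 8, 4096, 4096, 128, 128) ==
                  2ull * 2 * 8 * 4096 * 4096 * 128 * 2,
              "the two gemms are symmetric when hdim_q == hdim_v");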
@@ -197,7 +298,11 @@ int main(int argc, char* argv[])
     {
         Tensor<QDataType> q_host_ref({batch * nhead, seqlen_q, hdim_q});
        Tensor<KDataType> k_host_ref({batch * nhead, seqlen_k, hdim_q});
-        Tensor<VDataType> v_host_ref({batch * nhead, hdim_v, seqlen_k});
+        const auto v_lengths = std::array<ck::index_t, 3>{batch * nhead, hdim_v, seqlen_k};
+        const auto v_strides = is_v_rowmajor
+                                   ? std::array<ck::index_t, 3>{hdim_v * seqlen_k, 1, hdim_v}
+                                   : std::array<ck::index_t, 3>{hdim_v * seqlen_k, seqlen_k, 1};
+        Tensor<VDataType> v_host_ref(v_lengths, v_strides);
         Tensor<ODataType> o_host_ref({batch * nhead, seqlen_q, hdim_v});
         Tensor<ODataType> o_host_result_ref(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
@@ -212,8 +317,16 @@ int main(int argc, char* argv[])
         if(i_perm) k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
         else       k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
 
-        if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
-        else       v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
+        if constexpr (is_v_rowmajor) {
+            // v_host : (b, h, s, d); v_host_ref : (batch*nhead, hdim, seqlen)
+            if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[3], idx[2]) = self(idx); });
+            // v_host : (b, s, h, d); v_host_ref : (batch*nhead, hdim, seqlen)
+            else       v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[3], idx[1]) = self(idx); });
+        }
+        else {
+            if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
+            else       v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
+        }
 
         // reference
         reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
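
The ForEach remaps above fold (b, h) into a single batch index and, on the row-major
path, also swap the seqlen and hdim axes so that v_host_ref is always laid out as
(batch * nhead, hdim, seqlen). A standalone sketch of just the index map (hypothetical
helper, not the Tensor API):

#include <array>
#include <cassert>
#include <cstddef>

// i_perm == true : idx = (b, h, s, d); i_perm == false : idx = (b, s, h, d)
// result indexes v_host_ref as (b * nhead + h, d, s)
std::array<std::size_t, 3> map_v_rowmajor(std::array<std::size_t, 4> idx,
                                          std::size_t nhead, bool i_perm)
{
    return i_perm ? std::array<std::size_t, 3>{idx[0] * nhead + idx[1], idx[3], idx[2]}
                  : std::array<std::size_t, 3>{idx[0] * nhead + idx[2], idx[3], idx[1]};
}

int main()
{
    auto r = map_v_rowmajor({1, 2, 3, 4}, 8, true); // b=1, h=2, s=3, d=4
    assert(r[0] == 10 && r[1] == 4 && r[2] == 3);
    return 0;
}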
......
@@ -26,6 +26,8 @@ struct FmhaFwdKernel
     using VDataType = ck::remove_cvref_t<typename FmhaPipeline::VDataType>;
     using ODataType = ck::remove_cvref_t<typename FmhaPipeline::ODataType>;
 
+    using VLayout = ck::remove_cvref_t<typename FmhaPipeline::VLayout>;
+
     struct Kargs
     {
         const void* q_ptr;
@@ -126,7 +128,6 @@ struct FmhaFwdKernel
             i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o;
 
         // Q/K/V DRAM and DRAM window
-        // FIXME: assume layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k],
         const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
             q_ptr,
             make_tuple(kargs.seqlen_q, kargs.hdim_q),
@@ -141,12 +142,32 @@ struct FmhaFwdKernel
             Number<32>{},
             Number<1>{});
 
-        const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
-            v_ptr,
-            make_tuple(kargs.hdim_v, kargs.seqlen_k),
-            make_tuple(kargs.stride_v, 1),
-            Number<32>{},
-            Number<1>{});
+        const auto v_dram = [&]() {
+            if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
+            {
+                // V is stored as [seqlen_k, hdim_v]; build that view, then transpose it
+                // into the [hdim_v, seqlen_k] view the pipeline consumes
+                const auto v_dram_tmp = make_naive_tensor_view<AddressSpaceEnum::Global>(
+                    v_ptr,
+                    make_tuple(kargs.seqlen_k, kargs.hdim_v),
+                    make_tuple(kargs.stride_v, 1),
+                    Number<32>{},
+                    Number<1>{});
+                return transform_tensor_view(
+                    v_dram_tmp,
+                    make_tuple(make_pass_through_transform(kargs.hdim_v),
+                               make_pass_through_transform(kargs.seqlen_k)),
+                    make_tuple(Sequence<1>{}, Sequence<0>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+            else
+            {
+                return make_naive_tensor_view<AddressSpaceEnum::Global>(
+                    v_ptr,
+                    make_tuple(kargs.hdim_v, kargs.seqlen_k),
+                    make_tuple(kargs.stride_v, 1),
+                    Number<32>{},
+                    Number<1>{});
+            }
+        }();
 
         auto q_dram_window = make_tile_window(
             q_dram,
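
The lambda above never moves data: for the row-major layout it builds the natural
[seqlen_k, hdim_v] view and then re-indexes it as [hdim_v, seqlen_k] with two
pass-through transforms, so the rest of the kernel addresses V as hdim * seqlen in
both cases. A standalone analogue (plain C++, not the CK view API) of transposing a
2-D view by swapping lengths and strides:

#include <cassert>
#include <vector>

struct View2D
{
    const float* data;
    int len0, len1;       // logical lengths
    int stride0, stride1; // element strides
    float at(int i0, int i1) const { return data[i0 * stride0 + i1 * stride1]; }
};

View2D transpose(View2D v) { return {v.data, v.len1, v.len0, v.stride1, v.stride0}; }

int main()
{
    std::vector<float> v = {0, 1, 2, 3, 4, 5}; // V stored row-major as [seqlen_k=2, hdim_v=3]
    View2D v_seq_hdim{v.data(), 2, 3, 3, 1};
    View2D v_hdim_seq = transpose(v_seq_hdim); // logical [hdim_v, seqlen_k], zero data movement
    assert(v_hdim_seq.at(2, 1) == v_seq_hdim.at(1, 2));
    return 0;
}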
......
@@ -17,6 +17,7 @@
 #include "ck/tile_program/warp_tile/warp_gemm.hpp"
 #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck/tile_program/block_tile/block_reduce.hpp"
+#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp"
 
 namespace ck {
 namespace tile_program {
@@ -36,6 +37,7 @@ struct BlockFmhaPipelineQRKSVS
     using ODataType      = remove_cvref_t<typename Problem::ODataType>;
 
     using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
+    using VLayout        = remove_cvref_t<typename BlockFmhaShape::VLayout>;
 
     static constexpr bool kQLoadOnce = true; // whether Q loads the whole block length (hdim) at once
     static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -263,8 +265,20 @@ struct BlockFmhaPipelineQRKSVS
                 });
 
                 block_sync_lds();
-                store_tile(v_lds_window,
-                           tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
+                if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
+                {
+                    auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
+                        Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
+                    shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch);
+                    store_tile(
+                        v_lds_window,
+                        tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
+                }
+                else
+                {
+                    store_tile(v_lds_window,
+                               tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
+                }
                 move_tile_window(v_dram_window, {0, kK1});
 
                 const auto p =
@@ -282,11 +296,26 @@ struct BlockFmhaPipelineQRKSVS
                         p, Sequence<0, i_k1 * kK1>{}, Sequence<kM0, (i_k1 + 1) * kK1>{}),
                     v_lds_window);
                 block_sync_lds();
-                store_tile(v_lds_window,
-                           tile_elementwise_in(v_element_func, v)); // store next v
+                if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
+                {
+                    auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
+                        Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
+                    shuffle_distributed_tensor(v_shuffle_tmp, v);
+                    store_tile(v_lds_window,
+                               tile_elementwise_in(v_element_func,
+                                                   v_shuffle_tmp)); // store next v
+                }
+                else
+                {
+                    store_tile(v_lds_window,
+                               tile_elementwise_in(v_element_func, v)); // store next v
+                }
                 move_tile_window(v_dram_window, {0, kK1});
             });
         }
+
+        // move K tile windows
+        move_tile_window(k_dram_block_window, {kN0, 0});
+        i_total_loops++;
+
         // tail
         {
             block_sync_lds();
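
On the row-major path, the fragment each thread loads from V is contiguous along
hdim, while the LDS layout wants the data contiguous along seqlen; the
shuffle_distributed_tensor call rearranges the fragment in registers before the
store. A minimal standalone sketch of that "shuffle then store" idea (plain arrays
standing in for the distributed tensor and LDS):

#include <cassert>

int main()
{
    float frag[4][4], shuf[4][4], lds[16];
    for(int s = 0; s < 4; ++s)
        for(int d = 0; d < 4; ++d)
            frag[s][d] = float(s * 4 + d); // loaded along hdim: d is contiguous

    for(int s = 0; s < 4; ++s)
        for(int d = 0; d < 4; ++d)
            shuf[d][s] = frag[s][d]; // register shuffle: now s is contiguous

    for(int d = 0; d < 4; ++d)
        for(int s = 0; s < 4; ++s)
            lds[d * 4 + s] = shuf[d][s]; // store: "LDS" holds V as [hdim][seqlen]

    assert(lds[2 * 4 + 3] == frag[3][2]);
    return 0;
}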
@@ -295,10 +324,6 @@ struct BlockFmhaPipelineQRKSVS
                 v_lds_window);
             block_sync_lds();
         }
-
-        // move K tile windows
-        move_tile_window(k_dram_block_window, {kN0, 0});
-        i_total_loops++;
     } while(i_total_loops < num_total_loop);
 
     // finally, O
......
@@ -15,6 +15,7 @@
 #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
 #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
 #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 
 namespace ck {
 namespace tile_program {
@@ -23,6 +24,27 @@ namespace block {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaPipelineQRKSVSDefaultPolicy
 {
+    template <typename Problem>
+    __host__ __device__ static constexpr auto GetSmemKPackK()
+    {
+        // TODO: this is for 3d layout
+        using KDataType = remove_cvref_t<typename Problem::KDataType>;
+        return 16 / sizeof(KDataType);
+    }
+
+    template <typename Problem>
+    __host__ __device__ static constexpr auto GetSmemKPackV()
+    {
+        // TODO: this is for 3d layout
+        using VDataType = remove_cvref_t<typename Problem::VDataType>;
+        return 16 / sizeof(VDataType);
+    }
+
+    template <typename Problem>
+    __host__ __device__ static constexpr auto GetTransposedVectorloadV()
+    {
+        return 4; // TODO: fix me
+    }
+
     template <typename Problem, typename BlockGemm>
     __host__ __device__ static constexpr auto MakeQRegBlockDescriptor()
     {
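
GetSmemKPackK/GetSmemKPackV above size the LDS pack to one 128-bit (16-byte) vector
per access. Worked values (standalone; a 2-byte stand-in plays the role of half_t):

#include <cstdint>

using fp16_stand_in = std::uint16_t; // same sizeof as ck::half_t

static_assert(16 / sizeof(fp16_stand_in) == 8, "fp16: 8 elements per 128-bit pack");
static_assert(16 / sizeof(float) == 4,         "fp32: 4 elements per 128-bit pack");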
@@ -61,17 +83,18 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
     {
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+        constexpr index_t kKPack     = GetSmemKPackV<Problem>();
 
         constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
-            make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
+            make_tuple(Number<kKPerBlock / kKPack>{}, Number<kNPerBlock>{}, Number<kKPack>{}),
+            make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number<kKPack>{}, Number<1>{}),
             Number<8>{},
             Number<1>{});
 
         constexpr auto k_lds_block_desc = transform_tensor_descriptor(
             k_lds_block_desc_0,
             make_tuple(make_pass_through_transform(kNPerBlock),
-                       make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
+                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -82,25 +105,60 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
     template <typename Problem>
     __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor()
     {
+#if 0
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
         constexpr index_t kPad       = 1;
-        constexpr index_t kK1        = 8;
+        constexpr index_t kKPack     = GetSmemKPackV<Problem>();
 
         constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
-            make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
-            Number<kK1>{},
+            make_tuple(Number<kKPerBlock / kKPack>{}, Number<kNPerBlock>{}, Number<kKPack>{}),
+            make_tuple(Number<(kNPerBlock + kPad) * kKPack>{}, Number<kKPack>{}, Number<1>{}),
+            Number<kKPack>{},
             Number<1>{});
 
         constexpr auto v_lds_block_desc = transform_tensor_descriptor(
             v_lds_block_desc_0,
             make_tuple(make_pass_through_transform(kNPerBlock),
-                       make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
+                       make_merge_transform(
+                           make_tuple(Number<kKPerBlock / kKPack>{}, Number<kKPack>{}))),
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         return v_lds_block_desc;
+#else
+        using VDataType = remove_cvref_t<typename Problem::VDataType>;
+
+        constexpr index_t Banks        = 32; // TODO: need change based on arch
+        constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
+        constexpr index_t kKPack       = GetSmemKPackV<Problem>();
+        static_assert(PixelsPerRow % kKPack == 0);
+        constexpr index_t NPerRow    = PixelsPerRow / kKPack;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+        static_assert(kNPerBlock % NPerRow == 0);
+        static_assert(kKPerBlock % kKPack == 0);
+
+        constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(Number<kKPerBlock / kKPack>{},
+                       Number<kNPerBlock / NPerRow>{},
+                       Number<NPerRow>{},
+                       Number<kKPack>{}),
+            make_tuple(Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
+                       Number<PixelsPerRow + kKPack>{},
+                       Number<kKPack>{},
+                       Number<1>{}),
+            Number<kKPack>{},
+            Number<1>{});
+
+        constexpr auto v_lds_block_desc = transform_tensor_descriptor(
+            v_lds_block_desc_0,
+            make_tuple(
+                make_merge_transform(make_tuple(Number<kNPerBlock / NPerRow>{}, Number<NPerRow>{})),
+                make_merge_transform(make_tuple(Number<kKPerBlock / kKPack>{}, Number<kKPack>{}))),
+            make_tuple(Sequence<1, 2>{}, Sequence<0, 3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return v_lds_block_desc;
+#endif
     }
 
     template <typename Problem>
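
The #else branch above lays V out so that one LDS "row" spans all 32 banks
(PixelsPerRow elements) and then pads each row by one extra pack, so rows that are
logically adjacent along N start in different banks. Worked geometry for fp16 under
the stated assumptions (32 banks of 4 bytes each):

#include <cstddef>

constexpr std::size_t Banks        = 32;                     // assumed, per the TODO above
constexpr std::size_t elem_bytes   = 2;                      // sizeof(half_t)
constexpr std::size_t PixelsPerRow = Banks * 4 / elem_bytes; // 64 fp16 elements per bank row
constexpr std::size_t kKPack       = 16 / elem_bytes;        // 8 (one 128-bit pack)
constexpr std::size_t NPerRow      = PixelsPerRow / kKPack;  // 8 packs per bank row

static_assert(PixelsPerRow == 64 && NPerRow == 8);
// The row pitch is padded from 64 to 64 + 8 = 72 elements (16 bytes = 4 banks), so
// successive rows start four banks apart instead of all starting at bank 0.
static_assert(PixelsPerRow + kKPack == 72);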
@@ -192,25 +250,81 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
     __device__ static constexpr auto MakeVDramTileDistribution()
     {
         using VDataType = remove_cvref_t<typename Problem::VDataType>;
-        ;
+        using VLayout   = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
 
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
 
-        constexpr index_t K1 = 16 / sizeof(VDataType);
-        constexpr index_t K0 = kKPerBlock / K1;
-        constexpr index_t N2 = get_warp_size() / K0;
-        constexpr index_t N1 = kBlockSize / get_warp_size();
-        constexpr index_t N0 = kNPerBlock / (N2 * N1);
-
-        return make_static_tile_distribution(
-            StaticTileDistributionEncoding<Sequence<1>,
-                                           Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
-                                           Tuple<Sequence<1>, Sequence<1, 2>>,
-                                           Tuple<Sequence<1>, Sequence<2, 0>>,
-                                           Sequence<1, 2>,
-                                           Sequence<0, 1>>{});
+        if constexpr(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>)
+        {
+            constexpr index_t N1 = GetTransposedVectorloadV<Problem>();
+            constexpr index_t N0 = kNPerBlock / N1; // P
+
+            constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
+            static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
+
+            constexpr index_t K3     = total_pixels / N1;
+            constexpr index_t kKPack = GetSmemKPackV<Problem>();
+            static_assert(kKPack % K3 == 0);
+            constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
+            constexpr index_t K1 = get_warp_size() / (K2 * N0);
+            constexpr index_t K0 = kBlockSize / get_warp_size();
+            static_assert(kKPerBlock == K0 * K1 * K2 * K3);
+
+            return make_static_tile_distribution(
+                StaticTileDistributionEncoding<Sequence<1>,
+                                               Tuple<Sequence<N0, N1>, Sequence<K0, K1, K2, K3>>,
+                                               Tuple<Sequence<2>, Sequence<2, 1, 2>>,
+                                               Tuple<Sequence<0>, Sequence<1, 0, 2>>,
+                                               Sequence<2, 1>,
+                                               Sequence<3, 1>>{});
+        }
+        else
+        {
+            constexpr index_t K1 = 16 / sizeof(VDataType);
+            constexpr index_t K0 = kKPerBlock / K1;
+            constexpr index_t N2 = get_warp_size() / K0;
+            constexpr index_t N1 = kBlockSize / get_warp_size();
+            constexpr index_t N0 = kNPerBlock / (N2 * N1);
+
+            return make_static_tile_distribution(
+                StaticTileDistributionEncoding<Sequence<1>,
+                                               Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
+                                               Tuple<Sequence<1>, Sequence<1, 2>>,
+                                               Tuple<Sequence<1>, Sequence<2, 0>>,
+                                               Sequence<1, 2>,
+                                               Sequence<0, 1>>{});
+        }
+    }
+
+    template <typename Problem>
+    __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor()
+    {
+        // This descriptor is only used when the V layout is seqlen * hdim (RowMajor)
+        using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
+        static_assert(ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor>);
+
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+
+        constexpr index_t N1 = GetTransposedVectorloadV<Problem>();
+        constexpr index_t N0 = kNPerBlock / N1;
+
+        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
+        static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
+
+        constexpr index_t K3     = total_pixels / N1;
+        constexpr index_t kKPack = GetSmemKPackV<Problem>();
+        static_assert(kKPack % K3 == 0);
+        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
+        constexpr index_t K1 = get_warp_size() / (K2 * N0);
+        constexpr index_t K0 = kBlockSize / get_warp_size();
+
+        return make_static_tile_distribution(
+            StaticTileDistributionEncoding<Sequence<1>,
+                                           Tuple<Sequence<N0, N1>, Sequence<K0, K1, K2, K3>>,
+                                           Tuple<Sequence<2>, Sequence<2, 1, 2>>,
+                                           Tuple<Sequence<0>, Sequence<1, 0, 2>>,
+                                           Sequence<1, 2>,
+                                           Sequence<1, 3>>{});
     }
 
     template <typename Problem>
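
The transposed-load factorization above splits the per-thread footprint into
N0 * N1 (vector load along hdim) and K0 * K1 * K2 * K3 (waves / lanes / packs along
seqlen). A standalone consistency check for the two tile shapes in this PR, assuming
kBlockSize = 256, a 64-lane wave, N1 = 4 from GetTransposedVectorloadV, and
kKPack = 8 for fp16:

#include <cstddef>

constexpr bool factorization_ok(std::size_t kN1, std::size_t kK1)
{
    constexpr std::size_t kBlockSize = 256, wave = 64, N1 = 4, kKPack = 8;
    std::size_t N0           = kN1 / N1;
    std::size_t total_pixels = kN1 * kK1 / kBlockSize; // elements per thread
    std::size_t K3           = total_pixels / N1;
    std::size_t K2           = kKPack / K3;
    std::size_t K1           = wave / (K2 * N0);
    std::size_t K0           = kBlockSize / wave;
    return kK1 == K0 * K1 * K2 * K3; // mirrors the static_assert above
}

static_assert(factorization_ok(64, 32),  "hdim64 tile: kN1=64,  kK1=32");
static_assert(factorization_ok(128, 32), "hdim128 tile: kN1=128, kK1=32");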
@@ -224,11 +338,6 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
         TileGemmShape<Problem::BlockFmhaShape::kM0,
                       Problem::BlockFmhaShape::kN0,
                       Problem::BlockFmhaShape::kK0>>;
-        // using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<typename
-        // Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType,
-        // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<0>{}),
-        // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}),
-        // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>;
 
         using WarpGemm = warp::WarpGemmImpl<
             warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
@@ -256,7 +365,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy
         TileGemmShape<Problem::BlockFmhaShape::kM0,
                       Problem::BlockFmhaShape::kN1,
                       Problem::BlockFmhaShape::kK1>>;
-        // using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
+
         using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<
             typename Problem::PDataType,
             typename Problem::VDataType,
......
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 
 namespace ck {
 namespace tile_program {
@@ -12,7 +13,8 @@ template <typename BlockTile_, // Sequence<...
           typename Gemm0BlockWarps_,
           typename Gemm0WarpTile_,
           typename Gemm1BlockWarps_,
-          typename Gemm1WarpTile_>
+          typename Gemm1WarpTile_,
+          typename VLayout_ = ck::tensor_layout::gemm::RowMajor>
 struct TileFmhaShape
 {
     using BlockTile = remove_cvref_t<BlockTile_>;
@@ -29,6 +31,8 @@ struct TileFmhaShape
     static constexpr index_t kK0BlockLength =
         BlockTile::At(Number<5>{}); // total length of K0, used by pipelines that need to load Q
                                     // at once (or repeatedly load Q as a whole tile)
+
+    using VLayout = remove_cvref_t<VLayout_>; // RowMajor: seqlen*hdim, ColumnMajor: hdim*seqlen
 };
 
 } // namespace tile_program
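
Since VLayout_ defaults to RowMajor, callers that want the original hdim * seqlen
behavior now have to say so explicitly. A usage sketch mirroring the example file
above (type names abbreviated):

using FmhaShapeColMajorV = ck::tile_program::TileFmhaShape<
    ck::Sequence<128, 64, 32, 64, 32, 64>, // block tile: M0, N0, K0, N1, K1, K0L
    ck::Sequence<4, 1, 1>,                 // gemm0 block warps
    ck::Sequence<32, 32, 16>,              // gemm0 warp tile
    ck::Sequence<4, 1, 1>,                 // gemm1 block warps
    ck::Sequence<32, 32, 16>,              // gemm1 warp tile
    ck::tensor_layout::gemm::ColumnMajor>; // V as (bs, nhead) hdim * seqlen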
......
@@ -328,7 +328,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
     using CWarpDstrEncoding = StaticTileDistributionEncoding<
         Sequence<>,
         Tuple<Sequence<Impl::kCNLane>,
-              Sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+              Sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
         Tuple<Sequence<2, 1>>,
         Tuple<Sequence<1, 0>>,
         Sequence<2, 2>,
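
The swizzle-C fix regroups only the per-lane M factors; their product, and hence the
number of C elements each lane owns, is unchanged, which is consistent with the
commit message's "no perf change". A standalone sanity check with assumed values for
the Impl constants:

constexpr int kCM0PerLane = 4, kCM1PerLane = 4; // assumed values for illustration

static_assert(kCM0PerLane * kCM1PerLane ==
                  (kCM0PerLane / 2) * (kCM1PerLane * 2),
              "same per-lane element count, different grouping");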
......