Commit 8c967d76 authored by danyao12's avatar danyao12
Browse files

Fix deterministic-mode bugs in the batched FMHA backward kernels

parent 74f1516c
......@@ -766,7 +766,7 @@ struct FmhaBwdDQDKDVKernel
make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.hdim_q, 1),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
......@@ -1487,22 +1487,18 @@ struct FmhaBwdConvertQGradKernel
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.seqlen_q * kargs.hdim_q) +
batch_offset_dq;
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
constexpr auto dq_fold = 4;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q / dq_fold, kargs.hdim_q * dq_fold),
make_tuple(kargs.split_stride_dq_acc, kargs.hdim_q * dq_fold, 1),
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<1>{},
number<kM0 / dq_fold>{},
number<kQKHeaddim * dq_fold>{}),
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
}
else
......@@ -1538,12 +1534,10 @@ struct FmhaBwdConvertQGradKernel
auto dq_acc_dram_window = [&]() {
if constexpr(kIsDeterministic)
{
constexpr auto dq_fold = 4;
return make_tile_window(dq_acc_dram,
make_tuple(number<1>{},
number<kM0 / dq_fold>{},
number<kQKHeaddim * dq_fold>{}),
{0, i_m0 / dq_fold, 0});
return make_tile_window(
dq_acc_dram,
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
{0, i_m0, 0});
}
else
{
......
......@@ -52,7 +52,7 @@ struct BlockFmhaBwdConvertQGrad
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
Policy::template MakePostQGradDramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
......@@ -76,11 +76,11 @@ struct BlockFmhaBwdConvertQGrad
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window = make_tile_window(
dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>());
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
clear_tile(dq_acc);
......@@ -118,7 +118,7 @@ struct BlockFmhaBwdConvertQGrad
// declare dq
constexpr auto dq_converted_dstr =
Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>();
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
......@@ -130,8 +130,7 @@ struct BlockFmhaBwdConvertQGrad
});
});
constexpr auto dq_dstr =
Policy::template MakePostQGradDeterministicDramTileDistribution<Problem>();
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
......
......@@ -473,28 +473,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
// Deterministic mode staff
auto dq_buffer_view = dq_dram_block_window_tmp.get_bottom_tensor_view().get_buffer_view();
auto dq_tensor_desc =
dq_dram_block_window_tmp.get_bottom_tensor_view().get_tensor_descriptor();
auto seqlen_q = dq_tensor_desc.get_lengths()[number<0>{}];
auto hdim_q = dq_tensor_desc.get_lengths()[number<1>{}];
constexpr auto dq_fold = 4;
auto dq_write_tensor_desc =
make_naive_tensor_descriptor(make_tuple(seqlen_q / dq_fold, hdim_q * dq_fold),
make_tuple(hdim_q * dq_fold, 1),
number<kAlignmentQGrad>{},
number<1>{});
auto dq_tensor_view = tensor_view<decltype(dq_buffer_view), decltype(dq_write_tensor_desc)>{
dq_buffer_view, dq_write_tensor_desc};
auto dq_dram_window_deterministic =
make_tile_window(dq_tensor_view,
make_tuple(number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
{seqlen_q_start / dq_fold, 0});
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
......@@ -807,19 +785,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
if constexpr(kIsDeterministic)
{
auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
Policy::template MakeQGradWriteBlockDescriptor<Problem>());
dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
move_tile_window(dq_dram_window_deterministic, {kM0 / dq_fold, 0});
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
move_tile_window(dq_dram_window, {kM0, 0});
}
move_tile_window(dq_dram_window, {kM0, 0});
i_total_loops += 1;
seqlen_q_step += kM0;
......@@ -1047,12 +1019,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
if constexpr(kIsDeterministic)
{
auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
Policy::template MakeQGradWriteBlockDescriptor<Problem>());
dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
store_tile(dq_dram_window, dq_acc);
}
else
{
......
......@@ -167,7 +167,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
false>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
......@@ -534,91 +534,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <typename Problem>
// Builds the static tile distribution used to write the dQ accumulator to DRAM
// in deterministic mode. The outer encoding spreads M/N iterations over the
// warp grid (with a leading unit "split" dimension); the inner encoding maps
// each lane's MFMA C-fragment layout (kCM0PerLane/kCMLane x kCNLane/kCM1PerLane).
// NOTE(review): this is the deleted side of the commit diff — the commit
// removes the deterministic-specific distribution in favour of a plain
// MakePostQGradAccDramTileDistribution.
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDeterministicDramTileDistribution()
{
// Dispatch an MFMA warp-gemm for the QGrad/Acc data types at the problem's
// warp-tile shape; the trailing `true` selects the transposed-C variant
// (presumably — TODO confirm against WarpGemmMfmaDispatcher's parameter list).
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
typename Problem::QGradDataType,
typename Problem::AccDataType,
Problem::Shape::WarpTile::at(number<0>{}),
Problem::Shape::WarpTile::at(number<1>{}),
Problem::Shape::WarpTile::at(number<2>{}),
true>;
using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
// Warp grid within the block and the block-level tile extents.
constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
// How many warp-gemm tiles each warp iterates over per block dimension.
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
// Outer (block-level) encoding: a unit leading dimension, then
// (MIterPerWarp, MWarp) and (NIterPerWarp, NWarp); warps are indexed by the
// second component of dims 2 and 3.
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<1>, sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 3>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 3>,
sequence<0, 0, 0>>{};
// Inner (lane-level) encoding taken from the MFMA C-fragment attributes.
constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<1>,
sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
tuple<sequence<2, 3>>,
tuple<sequence<1, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{};
// Embed the lane-level encoding inside the block-level one and materialize
// the static distribution.
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
return dq_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDeterministicDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
typename Problem::QGradDataType,
typename Problem::AccDataType,
Problem::Shape::WarpTile::at(number<0>{}),
Problem::Shape::WarpTile::at(number<1>{}),
Problem::Shape::WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim;
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return dq_block_dstr;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 3>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
......@@ -1079,7 +1020,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
......@@ -1554,7 +1495,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
......@@ -1581,54 +1522,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return ds_block_dstr;
}
template <typename Problem>
// Builds the register-tile distribution used by the deterministic-mode dQ
// write-back (store of the Gemm4 dQ accumulator to the folded DRAM layout).
// NOTE(review): this is the deleted side of the commit diff — the commit
// removes the fold-based deterministic write path and stores dq_acc through
// the regular dq_dram_window instead.
CK_TILE_HOST_DEVICE static constexpr auto MakeQGradWriteBlockDescriptor()
{
// MFMA warp-gemm matching Gemm3's warp-tile shape; the trailing `true`
// presumably selects the transposed-C variant — TODO confirm against
// WarpGemmMfmaDispatcher.
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
// NOTE(review): warp counts come from Gemm4BlockWarps while the warp tile
// comes from Gemm3WarpTile — verify this mixed pairing is intentional.
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// Warp-gemm iterations per warp in each block dimension.
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
// Outer (block-level) encoding over (MIterPerWarp, MWarp) x (NIterPerWarp, NWarp).
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Inner (lane-level) encoding from the MFMA C-fragment layout.
constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
// Embed lane-level inside block-level and materialize the distribution.
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
return dq_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment