Commit 8c967d76 authored by danyao12

fix batch deterministic bugs

parent 74f1516c
@@ -766,7 +766,7 @@ struct FmhaBwdDQDKDVKernel
             make_naive_tensor_view<address_space_enum::global>(
                 dq_acc_ptr,
                 make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                make_tuple(kargs.hdim_q, 1),
+                make_tuple(kargs.stride_q, 1),
                 number<FmhaPipeline::kAlignmentQGrad>{},
                 number<1>{});
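
Note on the hunk above: make_naive_tensor_view pairs the lengths (seqlen_q, hdim_q) with the strides of the dq_acc view, and element (row, col) lives at row * row_stride + col. Using kargs.hdim_q as the row stride is therefore only correct when rows are packed back to back; kargs.stride_q is the stride the host actually passed. A minimal standalone sketch of that indexing (plain C++; the sizes below are made-up example values, not taken from the kernel):

    #include <cassert>
    #include <cstddef>

    // Offset of element (row, col) in a 2-D row-major view with an explicit row
    // stride, mirroring the (lengths, strides) pair given to make_naive_tensor_view.
    static std::size_t element_offset(std::size_t row, std::size_t col, std::size_t row_stride)
    {
        return row * row_stride + col;
    }

    int main()
    {
        const std::size_t hdim_q   = 64; // logical row width
        const std::size_t stride_q = 80; // hypothetical padded distance between rows

        // Packed assumption (old code): stride == hdim_q -> lands 16 elements short of row 1.
        assert(element_offset(1, 0, hdim_q) == 64);
        // Explicit stride (new code): lands on the real start of row 1.
        assert(element_offset(1, 0, stride_q) == 80);
        return 0;
    }
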
@@ -1487,22 +1487,18 @@ struct FmhaBwdConvertQGradKernel
         {
             const AccDataType* dq_acc_ptr =
                 reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
-                static_cast<long_index_t>(i_nhead_) * (kargs.seqlen_q * kargs.hdim_q) +
-                batch_offset_dq;
+                static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq;
             const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
-            constexpr auto dq_fold = 4;
             auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                 dq_acc_ptr,
-                make_tuple(nsplits, kargs.seqlen_q / dq_fold, kargs.hdim_q * dq_fold),
-                make_tuple(kargs.split_stride_dq_acc, kargs.hdim_q * dq_fold, 1),
+                make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
+                make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq, 1),
                 number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                 number<1>{});
             return pad_tensor_view(dq_acc_dram_naive,
-                                   make_tuple(number<1>{},
-                                              number<kM0 / dq_fold>{},
-                                              number<kQKHeaddim * dq_fold>{}),
+                                   make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                                    sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
         }
         else
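
The hunk above makes the same correction for the per-head base pointer and the 3-D dq_acc view in the convert kernel: the head offset now comes from kargs.nhead_stride_dq (plus the precomputed batch_offset_dq) instead of assuming each head spans exactly seqlen_q * hdim_q contiguous elements, and the dq_fold = 4 row-folding is dropped in favour of the plain (nsplits, seqlen_q, hdim_q) shape with an explicit row stride. A toy illustration of why the packed head-offset formula diverges once the layout carries padding (all numbers below are hypothetical):

    #include <cassert>
    #include <cstdint>

    int main()
    {
        // Hypothetical strided batch layout for dq_acc.
        const std::int64_t seqlen_q        = 4;
        const std::int64_t hdim_q          = 8;
        const std::int64_t nhead_stride_dq = 40;  // real distance between heads (padded)
        const std::int64_t batch_offset_dq = 160; // start of this batch in the buffer
        const std::int64_t i_nhead         = 1;

        // Old formula: heads assumed packed, seqlen_q * hdim_q apart.
        const std::int64_t packed  = i_nhead * (seqlen_q * hdim_q) + batch_offset_dq; // 192
        // New formula: use the head stride supplied in the kernel arguments.
        const std::int64_t strided = i_nhead * nhead_stride_dq + batch_offset_dq;     // 200

        assert(packed == 192 && strided == 200); // they disagree as soon as padding exists
        return 0;
    }
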
@@ -1538,12 +1534,10 @@ struct FmhaBwdConvertQGradKernel
         auto dq_acc_dram_window = [&]() {
             if constexpr(kIsDeterministic)
             {
-                constexpr auto dq_fold = 4;
-                return make_tile_window(dq_acc_dram,
-                                        make_tuple(number<1>{},
-                                                   number<kM0 / dq_fold>{},
-                                                   number<kQKHeaddim * dq_fold>{}),
-                                        {0, i_m0 / dq_fold, 0});
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
+                    {0, i_m0, 0});
             }
             else
             {
...
@@ -52,7 +52,7 @@ struct BlockFmhaBwdConvertQGrad
             make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                              dq_acc_dram_block_window_tmp.get_window_lengths(),
                              dq_acc_dram_block_window_tmp.get_window_origin(),
-                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());
+                             Policy::template MakePostQGradDramTileDistribution<Problem>());
         auto dq_acc = load_tile(dq_acc_dram_window);
         const auto dq = cast_tile<QGradDataType>(dq_acc);
@@ -76,11 +76,11 @@ struct BlockFmhaBwdConvertQGrad
         static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
-        auto dq_acc_dram_window = make_tile_window(
-            dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
-            dq_acc_dram_block_window_tmp.get_window_lengths(),
-            dq_acc_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>());
+        auto dq_acc_dram_window =
+            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
+                             dq_acc_dram_block_window_tmp.get_window_lengths(),
+                             dq_acc_dram_block_window_tmp.get_window_origin(),
+                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());
         auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
         clear_tile(dq_acc);
@@ -118,7 +118,7 @@ struct BlockFmhaBwdConvertQGrad
         // declare dq
         constexpr auto dq_converted_dstr =
-            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>();
+            Policy::template MakePostQGradAccDramTileDistribution<Problem>();
         auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
         sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
@@ -130,8 +130,7 @@ struct BlockFmhaBwdConvertQGrad
             });
         });
-        constexpr auto dq_dstr =
-            Policy::template MakePostQGradDeterministicDramTileDistribution<Problem>();
+        constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
         auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
         dq.get_thread_buffer() = dq_converted.get_thread_buffer();
...
@@ -473,28 +473,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
                              dq_dram_block_window_tmp.get_window_lengths(),
                              {seqlen_q_start, 0});
-        // Deterministic mode staff
-        auto dq_buffer_view = dq_dram_block_window_tmp.get_bottom_tensor_view().get_buffer_view();
-        auto dq_tensor_desc =
-            dq_dram_block_window_tmp.get_bottom_tensor_view().get_tensor_descriptor();
-        auto seqlen_q = dq_tensor_desc.get_lengths()[number<0>{}];
-        auto hdim_q = dq_tensor_desc.get_lengths()[number<1>{}];
-        constexpr auto dq_fold = 4;
-        auto dq_write_tensor_desc =
-            make_naive_tensor_descriptor(make_tuple(seqlen_q / dq_fold, hdim_q * dq_fold),
-                                         make_tuple(hdim_q * dq_fold, 1),
-                                         number<kAlignmentQGrad>{},
-                                         number<1>{});
-        auto dq_tensor_view = tensor_view<decltype(dq_buffer_view), decltype(dq_write_tensor_desc)>{
-            dq_buffer_view, dq_write_tensor_desc};
-        auto dq_dram_window_deterministic =
-            make_tile_window(dq_tensor_view,
-                             make_tuple(number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
-                             {seqlen_q_start / dq_fold, 0});
         using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
         using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
         using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
@@ -807,19 +785,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
             }
             if constexpr(kIsDeterministic)
             {
-                auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
-                    Policy::template MakeQGradWriteBlockDescriptor<Problem>());
-                dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
-                store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
-                move_tile_window(dq_dram_window_deterministic, {kM0 / dq_fold, 0});
+                store_tile(dq_dram_window, dq_acc);
             }
             else
             {
                 update_tile(dq_dram_window, dq_acc);
-                move_tile_window(dq_dram_window, {kM0, 0});
             }
+            move_tile_window(dq_dram_window, {kM0, 0});
             i_total_loops += 1;
             seqlen_q_step += kM0;
@@ -1047,12 +1019,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
             if constexpr(kIsDeterministic)
             {
-                auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
-                    Policy::template MakeQGradWriteBlockDescriptor<Problem>());
-                dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
-                store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
+                store_tile(dq_dram_window, dq_acc);
             }
             else
             {
...
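
Taken together with the window changes earlier, the two hunks above simplify the deterministic write path in the pipeline: the intermediate dq_write_reg_tensor with its dedicated MakeQGradWriteBlockDescriptor distribution and the separately folded dq_dram_window_deterministic are gone, and the deterministic branch now calls store_tile on dq_acc through the same dq_dram_window the non-deterministic branch uses, with the window advanced once after the branch. As the (nsplits, seqlen_q, hdim_q) dq_acc view above suggests, each key split writes its partial dQ into its own slice and the convert kernel reduces over splits in a fixed order; a two-line reminder (plain C++, unrelated to the kernel code) of why accumulation order matters for floats at all:

    #include <cassert>

    int main()
    {
        // Float addition is not associative, so the order in which partial dQ
        // contributions are summed changes the bit pattern of the result.
        const float a = 1e8f, b = -1e8f, c = 1.0f;
        assert((a + b) + c == 1.0f); // one fixed order keeps c
        assert(a + (b + c) == 0.0f); // a different order loses c entirely
        return 0;
    }
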
@@ -167,7 +167,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
-                                   Problem::kIsDeterministic ? true : false>;
+                                   false>;
         using BlockGemmPolicy =
             BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -534,91 +534,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDeterministicDramTileDistribution()
-    {
-        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
-                                                typename Problem::QGradDataType,
-                                                typename Problem::AccDataType,
-                                                Problem::Shape::WarpTile::at(number<0>{}),
-                                                Problem::Shape::WarpTile::at(number<1>{}),
-                                                Problem::Shape::WarpTile::at(number<2>{}),
-                                                true>;
-        using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
-        constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
-        constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
-        constexpr index_t kMPerBlock = Problem::Shape::kM0;
-        constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
-        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
-        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
-        constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<1>, sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
-            tuple<sequence<2, 3>>,
-            tuple<sequence<1, 1>>,
-            sequence<1, 2, 3>,
-            sequence<0, 0, 0>>{};
-        constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<1>,
-                  sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
-                  sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
-            tuple<sequence<2, 3>>,
-            tuple<sequence<1, 0>>,
-            sequence<1, 2, 3>,
-            sequence<0, 0, 1>>{};
-        constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
-        constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
-        return dq_block_dstr;
-    }
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDeterministicDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
     {
-        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
-                                                typename Problem::QGradDataType,
-                                                typename Problem::AccDataType,
-                                                Problem::Shape::WarpTile::at(number<0>{}),
-                                                Problem::Shape::WarpTile::at(number<1>{}),
-                                                Problem::Shape::WarpTile::at(number<2>{}),
-                                                true>;
-        constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
-        constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
+        using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+        constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::Shape::kM0;
-        constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
-        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
-        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
-        constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
-            tuple<sequence<1, 2>>,
-            tuple<sequence<1, 1>>,
-            sequence<1, 2>,
-            sequence<0, 0>>{};
-        constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            dq_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
-        constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
-        return dq_block_dstr;
+        constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim;
+        constexpr index_t K1 = 16 / sizeof(AccDataType);
+        constexpr index_t K0 = kKPerBlock / K1;
+        constexpr index_t M2 = get_warp_size() / K0;
+        constexpr index_t M1 = kBlockSize / get_warp_size();
+        constexpr index_t M0 = kMPerBlock / (M1 * M2);
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                       tuple<sequence<2>, sequence<2, 3>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2, 3>,
+                                       sequence<0, 0, 1>>{});
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
     {
         using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
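
For reference, the reworked MakePostQGradAccDramTileDistribution above splits the head dimension into K0 vectors of K1 elements (one 16-byte access each) and spreads the kM0 rows over M2 rows per wave, M1 waves per block, and M0 repeats per thread. A quick sanity check of those constants for one plausible configuration: fp32 accumulators, 64-lane waves, a 256-thread block, kM0 = kQKHeaddim = 128 (these numbers are illustrative assumptions, not values read from the Problem type):

    #include <cstddef>

    int main()
    {
        // Assumed example configuration (see lead-in; not taken from the kernel).
        constexpr std::size_t sizeof_acc = 4;   // AccDataType = float
        constexpr std::size_t warp_size  = 64;  // wavefront size
        constexpr std::size_t block_size = 256; // threads per block
        constexpr std::size_t kMPerBlock = 128; // kM0
        constexpr std::size_t kKPerBlock = 128; // kQKHeaddim

        constexpr std::size_t K1 = 16 / sizeof_acc;         // 4 elements per 16-byte vector
        constexpr std::size_t K0 = kKPerBlock / K1;          // 32 vectors across the head dim
        constexpr std::size_t M2 = warp_size / K0;           // 2 rows handled by one wave per pass
        constexpr std::size_t M1 = block_size / warp_size;   // 4 waves per block
        constexpr std::size_t M0 = kMPerBlock / (M1 * M2);   // 16 row-repeats per thread

        static_assert(K1 == 4 && K0 == 32 && M2 == 2 && M1 == 4 && M0 == 16, "unexpected split");
        return 0;
    }
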
@@ -1079,7 +1020,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
-                                   Problem::kIsDeterministic ? true : false>;
+                                   false>;
         constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
         constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
@@ -1554,7 +1495,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
                                    Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
-                                   Problem::kIsDeterministic ? true : false>;
+                                   false>;
         constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
         constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
@@ -1581,54 +1522,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
         return ds_block_dstr;
     }
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeQGradWriteBlockDescriptor()
-    {
-        using WarpGemm =
-            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
-                                   typename Problem::QDataType,
-                                   typename Problem::AccDataType,
-                                   Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
-                                   Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
-                                   Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
-                                   true>;
-        using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
-        constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
-        constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
-        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
-        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
-        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
-        constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
-            tuple<sequence<1, 2>>,
-            tuple<sequence<1, 1>>,
-            sequence<1, 2>,
-            sequence<0, 0>>{};
-        constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
-                  sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
-            tuple<sequence<1, 2>>,
-            tuple<sequence<1, 0>>,
-            sequence<1, 2>,
-            sequence<0, 1>>{};
-        constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
-        constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
-        return dq_block_dstr;
-    }
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
     {
...