Commit 70514fd8 authored by danyao12's avatar danyao12
Browse files

bwd rtn

parent 17c97f58
...@@ -500,8 +500,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -500,8 +500,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if not cond: if not cond:
continue continue
if receipt == 3: if receipt == 3:
cond = dtype in ['fp16', 'bf16'] cond = dtype in ['bf16']
cond &= bias in ['no', 'alibi'] cond &= bias in ['no']
cond &= dropout in ['no']
cond &= dpad == dvpad cond &= dpad == dvpad
cond &= deterministic == "f" cond &= deterministic == "f"
if not cond: if not cond:
......
...@@ -9,7 +9,7 @@ export CK_REPEAT=1 ...@@ -9,7 +9,7 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1' COMMON_ARGS='-v=1'
set -x set -x
for prec in "fp16" "bf16" ; do for prec in "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do for mode in 0 1 ; do
......
...@@ -227,6 +227,32 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) ...@@ -227,6 +227,32 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
#endif #endif
} }
// Element-wise cast of a distributed tile to OutDataType.
//
// On gfx908 / gfx90a / gfx94 targets every value in the calling thread's buffer is converted
// individually through float_to_bf16_raw with rounding mode 0, and the result is written into a
// new tensor that shares the input's tile distribution. On any other target the function falls
// back to a plain element-wise type_convert.
//
// NOTE(review): "rtn" presumably stands for round-to-nearest, and mode 0 is presumably the
// matching bf16_rounding_mode enumerator -- confirm against the enum definition.
// NOTE(review): the input elements are presumably fp32 and OutDataType bf16 -- nothing in this
// function enforces that; verify against callers.
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_rtn_bf16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // Name the rounding mode once instead of repeating the cast at the call site.
    constexpr auto rounding        = static_cast<bf16_rounding_mode>(0);
    constexpr auto src_dstr        = InTensor::get_tile_distribution();
    constexpr index_t num_elements = InTensor::get_thread_buffer_size();

    // The output keeps exactly the same distribution as the input tile.
    auto converted = make_static_distributed_tensor<OutDataType>(src_dstr);

    for(index_t elem = 0; elem < num_elements; ++elem)
    {
        converted.get_thread_buffer().at(elem) =
            float_to_bf16_raw<rounding>(in_dstr_tensors.get_thread_buffer()[elem]);
    }

    return converted;
#else
    // Generic path: defer to the default conversion for each element.
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST #if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either src or dst (or both) data type is under 1 dword // this function assumes either src or dst (or both) data type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
......
...@@ -125,6 +125,11 @@ struct BlockFmhaBwdConvertQGrad ...@@ -125,6 +125,11 @@ struct BlockFmhaBwdConvertQGrad
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) { sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2); constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
if constexpr(std::is_same_v<QGradDataType, bf16_t>)
dq_converted(n_i_j_idx) =
float_to_bf16_raw<static_cast<bf16_rounding_mode>(0)>(
dq_acc[n_i_j_idx]);
else
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]); dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
}); });
}); });
......
...@@ -611,18 +611,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -611,18 +611,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dropout.template Run<decltype(gemm_0), RandValOutputDataType>( dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window); seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
} }
const auto pt_gemm = [&]() { // const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
// }, pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const auto pt_dropped = [&]() {
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in([](const auto& x) { return x > 0.f ? x : 0.f; }, pt);
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
} }
else else
{ {
return cast_tile<GemmDataType>(pt); return pt;
} }
}(); }();
const auto pt_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(pt_dropped);
else
return cast_tile<GemmDataType>(pt_dropped);
}();
// STAGE 3, P^T@OGrad^T Gemm1 // STAGE 3, P^T@OGrad^T Gemm1
auto do_block_tile = load_tile(do_dram_window); auto do_block_tile = load_tile(do_dram_window);
...@@ -702,7 +718,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -702,7 +718,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto qt_reg_tensor = load_tile(qt_lds_read_window); auto qt_reg_tensor = load_tile(qt_lds_read_window);
block_sync_lds(); block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst); // const auto dst_gemm = cast_tile<GemmDataType>(dst);
const auto dst_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(dst);
else
return cast_tile<GemmDataType>(dst);
}();
Policy::template SGradTFromGemm2CToGemm3A<Problem, Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor), decltype(dst_reg_tensor),
......
...@@ -647,18 +647,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -647,18 +647,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dropout.template Run<decltype(gemm_0), RandValOutputDataType>( dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window); seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
} }
const auto pt_gemm = [&]() { // const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
// }, pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const auto pt_dropped = [&]() {
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in([](const auto& x) { return x > 0.f ? x : 0.f; }, pt);
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
} }
else else
{ {
return cast_tile<GemmDataType>(pt); return pt;
} }
}(); }();
const auto pt_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(pt_dropped);
else
return cast_tile<GemmDataType>(pt_dropped);
}();
// STAGE 3, P^T@OGrad^T Gemm1 // STAGE 3, P^T@OGrad^T Gemm1
Policy::template PTFromGemm0CToGemm1A<Problem, Policy::template PTFromGemm0CToGemm1A<Problem,
...@@ -733,7 +749,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -733,7 +749,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst); // const auto dst_gemm = cast_tile<GemmDataType>(dst);
const auto dst_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(dst);
else
return cast_tile<GemmDataType>(dst);
}();
Policy::template SGradTFromGemm2CToGemm3A<Problem, Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor), decltype(dst_reg_tensor),
...@@ -900,18 +922,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -900,18 +922,34 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
} }
// STAGE 3, P^T@OGrad^T Gemm1 // STAGE 3, P^T@OGrad^T Gemm1
const auto pt_gemm = [&]() { // const auto pt_gemm = [&]() {
// if constexpr(FmhaDropout::IsDropout)
// {
// return tile_elementwise_in(
// [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
// pt);
// }
// else
// {
// return cast_tile<GemmDataType>(pt);
// }
// }();
const auto pt_dropped = [&]() {
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in([](const auto& x) { return x > 0.f ? x : 0.f; }, pt);
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
} }
else else
{ {
return cast_tile<GemmDataType>(pt); return pt;
} }
}(); }();
const auto pt_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(pt_dropped);
else
return cast_tile<GemmDataType>(pt_dropped);
}();
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(pt_gemm)>( Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(pt_gemm)>(
pt_reg_tensor, pt_gemm); pt_reg_tensor, pt_gemm);
...@@ -969,7 +1007,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -969,7 +1007,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst); // const auto dst_gemm = cast_tile<GemmDataType>(dst);
const auto dst_gemm = [&]() {
if constexpr(std::is_same_v<GemmDataType, bf16_t>)
return impl::cast_tile_rtn_bf16_fp32<GemmDataType>(dst);
else
return cast_tile<GemmDataType>(dst);
}();
Policy::template SGradTFromGemm2CToGemm3A<Problem, Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor), decltype(dst_reg_tensor),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment