"vscode:/vscode.git/clone" did not exist on "be4e3133f74daaabed839edb05c37ce3beae54a9"
Commit 63c2d069 authored by danyao12

Merge branch 'attn-bwd-develop' into attn-bwd-dropout-pt1

parents a5bad9f2 82ce7f4e
@@ -401,9 +401,9 @@ int run(int argc, char* argv[])
         break;
     case 4:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
         break;
     case 5:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
...
@@ -44,7 +44,7 @@ struct BlockwiseDropout
         static_for<0, KRepeat, 1>{}([&](auto iK) {
             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
             in_thread_buf(offset) =
-                execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
             tmp_index = tmp_index + 1;
         });
     });
@@ -79,7 +79,7 @@ struct BlockwiseDropout
         static_for<0, KRepeat, 1>{}([&](auto iK) {
             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
             in_thread_buf(offset) =
-                execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
             z_thread_buf(offset) = tmp[tmp_index];
             tmp_index = tmp_index + 1;
         });
...
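Both BlockwiseDropout hunks above change the keep test from < to <=. A minimal host-side sketch (not the kernel code; the RNG and the million-sample loop are simplifying assumptions) of how the 16-bit threshold realizes the keep probability:

// Standalone sketch: the keep decision 'rand <= p_dropout_16bits' with rand
// uniform over [0, 65535] keeps values with probability ~ p_dropout.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>

int main()
{
    const float p_drop    = 0.2f;
    const float p_dropout = 1.0f - p_drop; // keep probability, as in the kernel
    const uint16_t p_dropout_16bits = uint16_t(std::floor(p_dropout * 65535.0));

    std::mt19937 gen(42);
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    int kept = 0;
    const int n = 1'000'000;
    for(int i = 0; i < n; ++i)
        kept += (uint16_t(dist(gen)) <= p_dropout_16bits); // '<=' as in the new code

    std::printf("empirical keep rate: %f (target %f)\n", double(kept) / n, double(p_dropout));
}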
@@ -1191,7 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
     const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
     const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-    const bool is_dropout = p_drop > 0.0f;
     const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                  rp_dropout);
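For context: p_dropout is the keep probability and rp_dropout its reciprocal; rescaling kept values by rp_dropout is the usual inverted-dropout trick that preserves the expected value. A hedged, standalone sketch of that identity, independent of the kernel types:

// Sketch: E[mask * x * rp_dropout] == x when mask keeps with probability p_keep.
#include <cstdio>
#include <random>

int main()
{
    const float p_drop = 0.1f, p_keep = 1.0f - p_drop, rp = 1.0f / p_keep;
    const float x = 3.0f;

    std::mt19937 gen(1);
    std::bernoulli_distribution keep(p_keep);

    double sum = 0.0;
    const int n = 1'000'000;
    for(int i = 0; i < n; ++i)
        sum += keep(gen) ? x * rp : 0.0f; // same rule scale_rp_dropout applies

    std::printf("E[dropout(x)] ~= %f, x = %f\n", sum / n, double(x));
}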
@@ -1866,8 +1865,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
     // save z to global
-    if(is_dropout)
-    {
     if(p_z_grid)
     {
         // P_dropped
@@ -1876,8 +1873,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             true>(
             s_slash_p_thread_buf, ph, z_tenor_buffer);
-        z_thread_copy_vgpr_to_global.Run(
-            z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
             z_tenor_buffer,
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1889,7 +1885,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
             s_slash_p_thread_buf, ph);
     }
-    }
     block_sync_lds(); // wait for gemm1 LDS read
...
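One plausible reading of why the is_dropout guard can be dropped (an inference from the diff, not a stated rationale): with the new <= comparison and p_drop == 0, the threshold is 65535, every 16-bit random value passes, and rp_dropout is 1, so dropout degenerates to the identity; under the old < comparison, values equal to 65535 would still be dropped. A small check of that reading:

// Sketch: with '<=', p_drop == 0 makes dropout a no-op, so the guard is redundant.
#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    const float p_dropout    = 1.0f; // p_drop == 0
    const uint16_t threshold = uint16_t(std::floor(p_dropout * 65535.0)); // == 65535
    for(uint32_t r = 0; r <= 65535; ++r)
        assert(uint16_t(r) <= threshold); // every random value is kept
    // rp_dropout == 1.0f, so kept values are unscaled: dropout is the identity.
}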
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
     {
         arg.out_.ForEach([&](auto& self, auto idx) {
             self(idx) =
-                arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
+                arg.ref_(idx) <= arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
         });
         return 0;
     }
...
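The reference path now applies the same <= rule and the same rp_dropout_ rescale as the device path, so given the z values the kernel saves to global memory, the two can be compared element-wise. A host-side sketch under those assumptions (the function name and flat vectors are illustrative, not the library API):

// Sketch of the reference rule after this change, over flat buffers.
#include <cstdint>
#include <vector>

std::vector<float> reference_dropout(const std::vector<float>& in,
                                     const std::vector<uint16_t>& z, // saved random values
                                     uint16_t p_dropout_in_16bits,
                                     float rp_dropout)
{
    std::vector<float> out(in.size());
    for(size_t i = 0; i < in.size(); ++i)
        out[i] = z[i] <= p_dropout_in_16bits ? in[i] * rp_dropout : 0.0f; // same '<=' rule
    return out;
}

int main()
{
    std::vector<float> in{1.0f, 2.0f};
    std::vector<uint16_t> z{0u, 65535u};
    auto out = reference_dropout(in, z, uint16_t(52428), 1.25f); // p_drop = 0.2
    // out[0] is kept and rescaled by 1.25; out[1] is dropped (65535 > 52428).
}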