Commit 272b7574 authored by danyao12
Browse files

fix drop==0 compiler issue in prototype1

parent 63c2d069
...@@ -400,9 +400,9 @@ int run(int argc, char* argv[]) ...@@ -400,9 +400,9 @@ int run(int argc, char* argv[])
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
......
...@@ -1265,7 +1265,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1265,7 +1265,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const bool is_dropout = p_drop > 0.0f;
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() * const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout); rp_dropout);
...@@ -1718,36 +1717,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1718,36 +1717,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) + block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx; y_thread_data_on_block_idx;
// // performs for y
// auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
// decltype(y_thread_desc_m0_m1_o0_o1),
// decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
// Sequence<0, 1, 2, 3>,
// 3, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
// y_thread_data_on_grid_idx);
// // performs for ygrad
// auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
// decltype(ygrad_thread_desc_m_o),
// decltype(ygrad_thread_desc_m_o.GetLengths()),
// Sequence<0, 1>,
// 1, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
// ygrad_thread_data_on_block_idx);
// performs for y // performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, DataType,
...@@ -1986,29 +1955,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1986,29 +1955,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global // save z to global
if(is_dropout) if(p_z_grid)
{ {
if(p_z_grid) // P_dropped
{ blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// P_dropped decltype(z_tenor_buffer),
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
decltype(z_tenor_buffer), s_slash_p_thread_buf, ph, z_tenor_buffer);
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer); z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_thread_copy_vgpr_to_global.Run( z_tenor_buffer,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), z_grid_buf);
z_tenor_buffer, }
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, else
z_grid_buf); {
} // P_dropped
else blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
{ s_slash_p_thread_buf, ph);
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
} }
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment