Commit 272b7574 authored by danyao12
Browse files

Fix drop == 0 compiler issue in prototype1

parent 63c2d069
......@@ -400,9 +400,9 @@ int run(int argc, char* argv[])
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
......
......@@ -1265,7 +1265,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const bool is_dropout = p_drop > 0.0f;
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
......@@ -1718,36 +1717,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// // performs for y
// auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
// decltype(y_thread_desc_m0_m1_o0_o1),
// decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
// Sequence<0, 1, 2, 3>,
// 3, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
// y_thread_data_on_grid_idx);
// // performs for ygrad
// auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
// decltype(ygrad_thread_desc_m_o),
// decltype(ygrad_thread_desc_m_o.GetLengths()),
// Sequence<0, 1>,
// 1, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
// ygrad_thread_data_on_block_idx);
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
......@@ -1986,8 +1955,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(is_dropout)
{
if(p_z_grid)
{
// P_dropped
......@@ -1996,8 +1963,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -2009,7 +1975,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment