Commit bab3161b authored by ltqin's avatar ltqin
Browse files

regular code

parent cf9ef868
......@@ -78,18 +78,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
}
// Builds a two-dimensional (MBlock, MPerBlock) view of the one-dimensional
// descriptor passed in and returns it.
//
// lse_grid_desc_m: descriptor whose length along dimension 0 is read as M.
// Returns: the descriptor produced by transform_tensor_descriptor below.
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers
// stripped, so each renamed line appears twice in a row (presumably the
// pre-commit spelling first, the post-commit spelling second) -- confirm
// against the repository before treating it as compilable code.
__host__ __device__ static constexpr auto
MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const ORSGridDesc_M& lse_grid_desc_m)
MakeORSGridDescriptor_MBlock_MPerBlock(const ORSGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
// Integer division: any remainder of M / MPerBlock is dropped.
// NOTE(review): presumably callers guarantee M is a multiple of MPerBlock -- confirm.
const index_t MBlock = M / MPerBlock;
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
const auto lse_grid_desc_mblock_mperblock = transform_tensor_descriptor(
lse_grid_desc_m,
// Unmerge transform with lengths (MBlock, MPerBlock); MPerBlock is passed as a
// compile-time Number<>, MBlock as a runtime value.
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
// Input dimension 0 ...
make_tuple(Sequence<0>{}),
// ... is mapped to output dimensions 0 and 1.
make_tuple(Sequence<0, 1>{}));
return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
return lse_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -185,15 +185,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
// if(get_thread_global_1d_id() == 1)
// {
// printf("y_thread_data_on_block_idx:{ %d, %d, %d,%d}, get_thread_local_1d_id: %d\n",
// y_thread_data_on_block_idx[I0],
// y_thread_data_on_block_idx[I1],
// y_thread_data_on_block_idx[I2],
// y_thread_data_on_block_idx[I3],
// get_thread_local_1d_id());
// }
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
......@@ -253,14 +244,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
auto ors_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(ors_grid_desc_m);
auto ors_grid_desc_mblock_mperblock =
MakeORSGridDescriptor_MBlock_MPerBlock(ors_grid_desc_m);
auto ors_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatORS,
FloatORS,
decltype(ors_thread_desc_mblock_mrepeat_mwave_mperxdl),
decltype(ors_grid_desc_mblock_mrepeat_mwave_mperxdl),
decltype(ors_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, 1>,
Sequence<0, 1>,
......@@ -268,7 +259,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
1,
InMemoryDataOperationEnum::Set,
1,
false>{ors_grid_desc_mblock_mrepeat_mwave_mperxdl,
false>{ors_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx_m, // mblock
get_thread_local_1d_id()), // mperxdl
ck::tensor_operation::element_wise::PassThrough{}};
......@@ -277,11 +268,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
ors_thread_copy_vgpr_to_global.Run(ors_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0),
y_dot_ygrad_thread_accum_buf,
ors_grid_desc_mblock_mrepeat_mwave_mperxdl,
ors_grid_desc_mblock_mperblock,
ors_grid_buf);
ignore = ors_thread_copy_vgpr_to_global;
ignore = ors_grid_desc_mblock_mrepeat_mwave_mperxdl;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment