"...composable_kernel_rocm.git" did not exist on "2778e99758e149a6cb5309ca307bf7c1e61a562f"
Commit 227076ba authored by ltqin
Browse files

change O to N

parent 796b544e
...@@ -115,7 +115,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -115,7 +115,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_N_
{ {
static_assert(BlockSize_ == BlockSliceLength_M_); static_assert(BlockSize_ == BlockSliceLength_M_);
static constexpr auto ThreadSliceLength_M = Number<1>{}; static constexpr auto ThreadSliceLength_M = Number<1>{};
...@@ -134,7 +134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -134,7 +134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
using DstBufType = StaticBuffer<AddressSpaceEnum::Vgpr, FloatD, ThreadSliceLength_M, true>; using DstBufType = StaticBuffer<AddressSpaceEnum::Vgpr, FloatD, ThreadSliceLength_M, true>;
}; };
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>; using YDotYGrad_M_N = YDotYGrad_M_N_<BlockSize, MPerBlock, NPerBlock>;
__device__ static void Run(const InputDataType* __restrict__ p_y_grid, __device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
...@@ -169,20 +169,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -169,20 +169,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
constexpr auto d_thread_desc_mblock_mrepeat_mwave_mperxdl = constexpr auto d_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1));
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto y_thread_desc_m0_m1_n0_n1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O)); I1, YDotYGrad_M_N::ThreadSliceLength_M, I1, YDotYGrad_M_N::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc = constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1, make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M, YDotYGrad_M_N::ThreadClusterLength_M,
I1, I1,
YDotYGrad_M_O::ThreadClusterLength_O>{}, YDotYGrad_M_N::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{}); Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx = const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx = const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths(); y_thread_cluster_idx * y_thread_desc_m0_m1_n0_n1.GetLengths();
const auto y_thread_data_on_grid_idx = const auto y_thread_data_on_grid_idx =
make_multi_index( make_multi_index(
...@@ -194,19 +194,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -194,19 +194,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
InputDataType, InputDataType,
FloatD, FloatD,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_n0_n1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()), decltype(y_thread_desc_m0_m1_n0_n1.GetLengths()),
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, // SrcVectorDim 3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector YDotYGrad_M_N::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector 1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */, true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock, false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock,
y_thread_data_on_grid_idx); y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto y_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto ygrad_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{}; auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_N::DstBufType{};
// clear accum buffers // clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear(); y_dot_ygrad_thread_accum_buf.Clear();
...@@ -216,19 +216,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -216,19 +216,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{ {
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock, yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
y_grid_buf, y_grid_buf,
y_thread_desc_m0_m1_o0_o1, y_thread_desc_m0_m1_n0_n1,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
y_thread_buf); y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock, yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_nblock_nperblock,
ygrad_grid_buf, ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1, y_thread_desc_m0_m1_n0_n1,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
ygrad_thread_buf); ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) { static_for<0, YDotYGrad_M_N::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) { static_for<0, YDotYGrad_M_N::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset = constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO)); y_thread_desc_m0_m1_n0_n1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) += y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}]; y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment