Commit b41172a3 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Some tiny updates

parent b23b3d71
...@@ -918,7 +918,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -918,7 +918,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
// C0 mask // C0 mask
const auto c0_matrix_mask = const auto c0_matrix_mask =
......
...@@ -448,19 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -448,19 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{});
}
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides) const std::vector<index_t>& v_gs_os_ns_strides)
...@@ -988,7 +975,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -988,7 +975,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
// C0 mask // C0 mask
const auto c0_matrix_mask = const auto c0_matrix_mask =
......
...@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_gs_ms_strides[NumDimG - 1]));
// C0 mask // C0 mask
const auto c0_matrix_mask = const auto c0_matrix_mask =
......
...@@ -795,7 +795,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -795,7 +795,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
// C0 mask // C0 mask
const auto c0_matrix_mask = const auto c0_matrix_mask =
......
...@@ -1620,7 +1620,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1620,7 +1620,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize()); p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // ygrad dynamic buffer used for calculating y_dot_dy
const auto ygrad_grid_buf1 = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
// ygrad dynamic buffer used for calculating dV = Pdrop^T * dY or dPdrop = dY * V^T
const auto ygrad_grid_buf2 = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize()); p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
...@@ -2208,7 +2212,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2208,7 +2212,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
y_thread_buf); y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock, yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf, ygrad_grid_buf1,
y_thread_desc_m0_m1_o0_o1, y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
ygrad_thread_buf); ygrad_thread_buf);
...@@ -2492,7 +2496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2492,7 +2496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// preload data into LDS // preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf); ygrad_grid_buf2);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step); ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
...@@ -2513,7 +2517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2513,7 +2517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
gemm1_a_thread_buf); gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf); ygrad_grid_buf2);
block_sync_lds(); block_sync_lds();
...@@ -2553,7 +2557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2553,7 +2557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
// preload data into LDS // preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf); ygrad_grid_buf2);
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step); ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
...@@ -2570,7 +2574,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2570,7 +2574,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) { static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf); ygrad_grid_buf2);
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment