Commit c416b823 authored by danyao12's avatar danyao12
Browse files

recover gridwise decoder

parent f01a06c4
......@@ -1867,11 +1867,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// gemm dP
// dP = dY * V^T
......
......@@ -1775,11 +1775,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// S = Q * K^T
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
q_grid_desc_k0_m_k1,
......
......@@ -1798,11 +1798,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// load ygrad
gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1,
......
......@@ -1721,11 +1721,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
//
// calculate Y dot dY
......
......@@ -926,11 +926,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
......
......@@ -1080,11 +1080,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
// if(c0_matrix_mask.IsTileSkippable(
// m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
// {
// continue;
// }
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment