Commit b0ae034e authored by ltqin's avatar ltqin
Browse files

Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask

parents 07206672 a0044283
......@@ -17,6 +17,8 @@ add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_m
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_train_v2 grouped_multihead_attention_train_v2.cpp)
add_example_executable(example_batched_multihead_attention_train_v2 batched_multihead_attention_train_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v3 grouped_multihead_attention_backward_v3.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -186,8 +186,14 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_z_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_lse_grid;
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -195,10 +201,15 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m;
ignore = vgrad_grid_desc_n_o;
ignore = ygrad_grid_desc_o0_m_o1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
......
......@@ -185,8 +185,14 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_z_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_lse_grid;
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -194,10 +200,15 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m;
ignore = vgrad_grid_desc_n_o;
ignore = ygrad_grid_desc_m0_o_m1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
......
......@@ -73,7 +73,7 @@ __global__ void
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -140,7 +140,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......@@ -175,7 +175,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......@@ -191,8 +191,14 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_z_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_lse_grid;
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -200,15 +206,21 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_o0_m_o1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = nblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
ignore = raw_m_padded;
ignore = raw_n_padded;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......
......@@ -190,8 +190,14 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_z_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_lse_grid;
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -199,15 +205,21 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_m0_o_m1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = nblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
ignore = raw_m_padded;
ignore = raw_n_padded;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -1038,7 +1050,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
return ave_time;
}
......
......@@ -168,6 +168,8 @@ __global__ void
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_z_grid;
ignore = p_lse_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -177,10 +179,17 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = lse_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_dropout_in_16bits;
ignore = p_dropout_rescale;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......
......@@ -176,6 +176,8 @@ __global__ void
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_z_grid;
ignore = p_lse_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -185,10 +187,19 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = lse_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_dropout_in_16bits;
ignore = p_dropout_rescale;
ignore = seed;
ignore = offset;
ignore = raw_m_padded;
ignore = raw_n_padded;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......
......@@ -176,6 +176,10 @@ __global__ void
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = p_dropout_in_16bits;
ignore = p_dropout_rescale;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment