Commit 0b472e28 authored by ltqin
Browse files

Grouped kernels: remove the y_grid_desc_mblock_mperblock_oblock_operblock parameter

parent 15713b20
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 32 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -260,7 +260,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -260,7 +260,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate int init_method = 1; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true; bool time_kernel = true;
// Overall QKV matrices shape // Overall QKV matrices shape
......
...@@ -174,37 +174,36 @@ __global__ void ...@@ -174,37 +174,36 @@ __global__ void
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_a_grid + a_batch_offset, p_b_grid + b_batch_offset,
p_b_grid + b_batch_offset, z_matrix_ptr,
z_matrix_ptr, p_b1_grid + b1_batch_offset,
p_b1_grid + b1_batch_offset, p_lse_grid + lse_batch_offset,
p_lse_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_d_grid + lse_batch_offset, p_ygrad_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_qgrad_grid + a_batch_offset,
p_qgrad_grid + a_batch_offset, p_kgrad_grid + b_batch_offset,
p_kgrad_grid + b_batch_offset, p_vgrad_grid + b1_batch_offset,
p_vgrad_grid + b1_batch_offset, p_shared,
p_shared, a_element_op,
a_element_op, b_element_op,
b_element_op, acc_element_op,
acc_element_op, b1_element_op,
b1_element_op, c_element_op,
c_element_op, a_grid_desc_ak0_m_ak1,
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
b_grid_desc_bk0_n_bk1, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, b1_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, lse_grid_desc_m,
lse_grid_desc_m, d_grid_desc_m,
d_grid_desc_m, ygrad_grid_desc_o0_m_o1,
ygrad_grid_desc_o0_m_o1, block_2_ctile_map,
block_2_ctile_map, c0_matrix_mask,
c0_matrix_mask, p_drop,
p_drop, ph,
ph, z_random_matrix_offset,
z_random_matrix_offset, raw_n_padded,
raw_n_padded, i);
i);
} }
} }
else else
......
...@@ -173,37 +173,36 @@ __global__ void ...@@ -173,37 +173,36 @@ __global__ void
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
{ {
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_a_grid + a_batch_offset, p_b_grid + b_batch_offset,
p_b_grid + b_batch_offset, z_matrix_ptr,
z_matrix_ptr, p_b1_grid + b1_batch_offset,
p_b1_grid + b1_batch_offset, p_lse_grid + lse_batch_offset,
p_lse_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_d_grid + lse_batch_offset, p_ygrad_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_qgrad_grid + a_batch_offset,
p_qgrad_grid + a_batch_offset, p_kgrad_grid + b_batch_offset,
p_kgrad_grid + b_batch_offset, p_vgrad_grid + b1_batch_offset,
p_vgrad_grid + b1_batch_offset, p_shared,
p_shared, a_element_op,
a_element_op, b_element_op,
b_element_op, acc_element_op,
acc_element_op, b1_element_op,
b1_element_op, c_element_op,
c_element_op, a_grid_desc_ak0_m_ak1,
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
b_grid_desc_bk0_n_bk1, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, b1_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, lse_grid_desc_m,
lse_grid_desc_m, d_grid_desc_m,
d_grid_desc_m, ygrad_grid_desc_m0_o_m1,
ygrad_grid_desc_m0_o_m1, block_2_ctile_map,
block_2_ctile_map, c0_matrix_mask,
c0_matrix_mask, p_drop,
p_drop, ph,
ph, z_random_matrix_offset,
z_random_matrix_offset, raw_n_padded,
raw_n_padded, i);
i);
} }
} }
else else
......
...@@ -179,7 +179,6 @@ __global__ void ...@@ -179,7 +179,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -216,7 +215,6 @@ __global__ void ...@@ -216,7 +215,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -733,8 +731,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -733,8 +731,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -885,21 +881,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -885,21 +881,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart); const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o))
{
y_grid_desc_mblock_mperblock_oblock_operblock =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
...@@ -973,7 +958,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -973,7 +958,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m, lse_grid_desc_m,
k_grid_desc_n_k, k_grid_desc_n_k,
......
...@@ -179,7 +179,6 @@ __global__ void ...@@ -179,7 +179,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -216,7 +215,6 @@ __global__ void ...@@ -216,7 +215,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -740,8 +738,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -740,8 +738,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -891,21 +887,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -891,21 +887,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart); const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o))
{
y_grid_desc_mblock_mperblock_oblock_operblock =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o);
}
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp = const index_t grid_size_grp =
...@@ -975,7 +960,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -975,7 +960,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m, lse_grid_desc_m,
k_grid_desc_n_k, k_grid_desc_n_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment