Commit 0b472e28 authored by ltqin
Browse files

Grouped kernels: remove the y_grid_desc_mblock_mperblock_oblock_operblock parameter

parent 15713b20
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -260,7 +260,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
int init_method = 1; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
......
......@@ -174,8 +174,7 @@ __global__ void
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
......
......@@ -173,8 +173,7 @@ __global__ void
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
......
......@@ -179,7 +179,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
......@@ -216,7 +215,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
......@@ -733,8 +731,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -885,21 +881,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o))
{
y_grid_desc_mblock_mperblock_oblock_operblock =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
......@@ -973,7 +958,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m,
k_grid_desc_n_k,
......
......@@ -179,7 +179,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
......@@ -216,7 +215,6 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
......@@ -740,8 +738,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -891,21 +887,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o))
{
y_grid_desc_mblock_mperblock_oblock_operblock =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o);
}
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
......@@ -975,7 +960,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m,
k_grid_desc_n_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment