Commit 9809f5d4 authored by danyao12
Browse files

Remove hardcoded V_O1/Y_O1/Y_M1 values; derive them from BK1/AK1

parent de4494a2
......@@ -357,9 +357,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = AK1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......
......@@ -364,8 +364,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = AK1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -556,7 +557,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<BK1>{});
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Z in Gemm0 C position
......
......@@ -304,9 +304,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = AK1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......
......@@ -311,8 +311,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = AK1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -494,7 +495,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<BK1>{});
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Z in Gemm0 C position
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment