Commit 3ccfb0ae authored by aska-0096

(2/5) bilinear gemm pass, perf bug: skipping A LDS has lower performance than skipping B LDS

parent c713d224
@@ -80,33 +80,33 @@ using DeviceOpInstance =
BElementOp,
CDEElementOp,
GemmSpec,
2,
1,
128,
64,
128,
64,
8,
64,
4,
16,
16,
2,
1,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
1,
1,
S<1, 32, 1, 4>,
S<1, 64, 1, 2>,
8>;
int main(int argc, char* argv[])
......
@@ -87,6 +87,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = max number of elements per vector access
static constexpr auto K1Number = Number<K1>{};
@@ -98,8 +99,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
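A minimal sketch of how the three switches combine after this change (MWaves and NumPrefetch below are assumed, illustrative values, not fixed by this hunk):

// Sketch, editor's illustration: with the manual override now off, an operand
// skips LDS only when its auto heuristic allows it and prefetch depth is 1.
constexpr int MWaves      = 1; // assumption: a single wave tiles M
constexpr int NumPrefetch = 1; // assumption: single-stage prefetch
constexpr bool BEnableLds_auto = MWaves == 1 ? false : true;
constexpr bool BEnableLds_manu = false; // flipped off by this commit
constexpr bool BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static_assert(!BEnableLds, "under these assumptions B bypasses LDS");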
@@ -144,18 +145,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
else
{
constexpr auto A_KRow = WmmaK / K1;
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
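A worked check of the new K-dimension split above, assuming WmmaK = 16, K1 = 8, and K = 256 (sizes not fixed by this hunk):

// Sketch: K now unmerges into A_KWmma x A_K0PerWmma x A_KRow x A_K1
// (output dims 0, 3, 4, 6), with A_KRow pinned to 2 by this commit.
constexpr int WmmaK = 16, K1 = 8, K = 256;       // assumed sizes
constexpr int A_KRow      = 2;
constexpr int A_K0PerWmma = WmmaK / A_KRow / K1; // 16 / 2 / 8 = 1
constexpr int A_KWmma     = K / WmmaK;           // 256 / 16 = 16
static_assert(A_KWmma * A_K0PerWmma * A_KRow * K1 == K, "the split must cover K");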
@@ -195,18 +199,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
else
{
constexpr auto B_KRow = WmmaK / K1;
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
@@ -438,14 +445,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
else
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I5);
arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
}
}();
float ave_time = 0;
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
GridwiseOp,
ADataType,
@@ -462,9 +466,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BElementwiseOperation,
CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
true>; // Last Option is W/O
has_main_k_block_loop>; // last option: with/without main K-block loop
ave_time =
return
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
@@ -482,48 +486,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.b_element_op_,
arg.cde_element_op_,
arg.block_2_ctile_map_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc,
arg.b_grid_desc,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.block_2_ctile_map_);
return launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
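The two duplicated launch paths collapse into a single generic lambda; a self-contained sketch of the integral_constant dispatch pattern (kernel_stub, run_sketch, and the main-loop condition are simplified stand-ins, not the library's code):

#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop>
void kernel_stub() // stand-in for the real kernel template's last parameter
{
    std::printf("has_main_k_block_loop = %d\n", HasMainKBlockLoop ? 1 : 0);
}

float run_sketch(int K, int KPerBlock)
{
    auto launch_kernel = [&](auto has_main_k_block_loop) {
        // the flag's type carries the bool, so it becomes a compile-time
        // template argument rather than a runtime branch inside the kernel
        const auto kernel = kernel_stub<decltype(has_main_k_block_loop)::value>;
        kernel();
        return 0.0f; // stand-in for launch_and_time_kernel(...)
    };
    if(K / KPerBlock > 1) // assumed heuristic for CalculateHasMainKBlockLoop
        return launch_kernel(std::integral_constant<bool, true>{});
    return launch_kernel(std::integral_constant<bool, false>{});
}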
// polymorphic
......
@@ -382,11 +382,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType,
@@ -400,9 +396,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O
has_main_k_block_loop>;
ave_time = launch_and_time_kernel(stream_config,
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
@@ -417,42 +413,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
......
@@ -379,10 +379,23 @@ struct GridwiseGemmMultipleD_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
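A hedged check that the new lengths and strides describe a densely packed per-thread buffer (the dimension values below are assumed):

// Sketch: lengths (KWmmaPerblock, MRepeat, 1, K0PerWmma, 1, 1, K1) with the
// strides above leave no padding between elements.
constexpr int KWmmaPerblock = 2, MRepeat = 2, K0PerWmma = 1, K1 = 8; // assumed
constexpr int stride0 = MRepeat * K0PerWmma * K1; // step between KWmma slices
constexpr int size = KWmmaPerblock * MRepeat * K0PerWmma * K1; // unit dims drop out
static_assert(stride0 * KWmmaPerblock == size, "packed layout, no holes");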
@@ -413,10 +426,23 @@ struct GridwiseGemmMultipleD_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
@@ -436,7 +462,7 @@ struct GridwiseGemmMultipleD_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
@@ -456,7 +482,7 @@ struct GridwiseGemmMultipleD_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
@@ -471,45 +497,33 @@ struct GridwiseGemmMultipleD_Wmma
constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_KRow_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}));
}
}();
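With the freeze-transform workaround dropped, the LDS-bypass wave view becomes a plain packed 6-d descriptor; a small row-major offset check under assumed sizes:

// Sketch: a packed (KWmma*K0PerWmma, MRepeat, 1, A_KRow, 1, A_K1) descriptor
// indexes row-major; unit dimensions contribute nothing to the offset.
constexpr int MRepeat = 2, A_KRow = 2, A_K1 = 8; // assumed
constexpr int offset =
    ((((1 * MRepeat + 1) * 1 + 0) * A_KRow + 1) * 1 + 0) * A_K1 + 3; // idx (1,1,0,1,0,3)
static_assert(offset == 59, "((1*2+1)*2+1)*8 + 3 = 59");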
@@ -525,42 +539,31 @@ struct GridwiseGemmMultipleD_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
// KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_KRow_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}));
}
}();
@@ -620,9 +623,9 @@ struct GridwiseGemmMultipleD_Wmma
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4),
a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5));
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
}
};
@@ -635,9 +638,9 @@ struct GridwiseGemmMultipleD_Wmma
else
{
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5));
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
}
};
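A hedged round-trip of the M/K recovery from the 7-d descriptor lengths (the lengths below are assumed, chosen to describe a 256 x 256 problem):

// Sketch: M comes from dims I1, I2, I5 and K from dims I0, I3, I4, I6.
// Assumed lengths: KWmma, MBlock*MRepeat, MWaves, K0PerWmma, KRow, MPerWmma, K1.
constexpr int len[7] = {16, 8, 2, 1, 2, 16, 8};
constexpr int M = len[1] * len[2] * len[5];          // 8 * 2 * 16 = 256
constexpr int K = len[0] * len[3] * len[4] * len[6]; // 16 * 1 * 2 * 8 = 256
static_assert(M == 256 && K == 256, "round-trips the assumed M x K problem");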
@@ -837,7 +840,8 @@ struct GridwiseGemmMultipleD_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else
{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
}
}();
@@ -888,8 +892,9 @@ struct GridwiseGemmMultipleD_Wmma
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
@@ -902,11 +907,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
@@ -914,6 +920,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32) / 16,
get_thread_local_1d_id() % 16,
0));
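A small illustration of the lane-id decomposition behind that start index, assuming a wave size of 32 (implied by the divisions above):

// Sketch: flat lane 21 -> wave 21/32 = 0, KRow (21%32)/16 = 1, lane-in-wmma
// 21%16 = 5; the extra 0 is the start of the new K0PerWmma coordinate.
constexpr int tid   = 21; // assumed flat thread id within the block
constexpr int wave  = tid / 32;
constexpr int k_row = (tid % 32) / 16;
constexpr int lane  = tid % 16;
static_assert(wave == 0 && k_row == 1 && lane == 5, "decomposition of lane 21");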
@@ -967,7 +974,8 @@ struct GridwiseGemmMultipleD_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> K0PerWmma -> KRow -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
constexpr auto K0PerWmma = WmmaK / 2 / K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
@@ -979,11 +987,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
@@ -991,6 +1000,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32) / 16,
get_thread_local_1d_id() % 16,
0));
......
@@ -655,7 +655,7 @@ struct GridwiseGemm_Wmma
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
......