Commit 3ccfb0ae authored by aska-0096's avatar aska-0096
Browse files

(2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds

parent c713d224
...@@ -80,33 +80,33 @@ using DeviceOpInstance = ...@@ -80,33 +80,33 @@ using DeviceOpInstance =
BElementOp, BElementOp,
CDEElementOp, CDEElementOp,
GemmSpec, GemmSpec,
2, 1,
128, 128,
64, 64,
128,
64, 64,
8, 64,
4,
16, 16,
16, 16,
2, 1,
4, 4,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
1, 1,
1, 1,
S<1, 32, 1, 4>, S<1, 64, 1, 2>,
8>; 8>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
...@@ -87,6 +87,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -87,6 +87,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
...@@ -98,8 +99,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -98,8 +99,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true; static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = true; static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
...@@ -144,18 +145,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -144,18 +145,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
else else
{ {
constexpr auto A_KRow = WmmaK / K1; constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK; const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock; const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))), make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -195,18 +199,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -195,18 +199,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
else else
{ {
constexpr auto B_KRow = WmmaK / K1; constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK; const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_n_k, b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))), make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -438,14 +445,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -438,14 +445,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
else else
{ {
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) * return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I5); arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
} }
}(); }();
float ave_time = 0; auto launch_kernel = [&](auto has_main_k_block_loop) {
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
GridwiseOp, GridwiseOp,
ADataType, ADataType,
...@@ -462,9 +466,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -462,9 +466,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
true>; // Last Option is W/O has_main_k_block_loop>; // Last Option is W/O
ave_time = return
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
...@@ -482,48 +486,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -482,48 +486,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< return launch_kernel(integral_constant<bool, false>{});
GridwiseOp,
ADataType,
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc,
arg.b_grid_desc,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time;
} }
// polymorphic // polymorphic
......
...@@ -382,11 +382,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -382,11 +382,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
} }
}(); }();
auto launch_kernel = [&](auto has_main_k_block_loop) {
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_wmma< const auto kernel = kernel_gemm_wmma<
GridwiseGemm, GridwiseGemm,
ADataType, ADataType,
...@@ -400,9 +396,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -400,9 +396,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O has_main_k_block_loop>;
ave_time = launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
...@@ -417,42 +413,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -417,42 +413,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = kernel_gemm_wmma< return launch_kernel(integral_constant<bool, false>{});
GridwiseGemm,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time;
} }
// polymorphic // polymorphic
......
...@@ -379,10 +379,23 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -379,10 +379,23 @@ struct GridwiseGemmMultipleD_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1)); Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
} }
}(); }();
...@@ -413,10 +426,23 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -413,10 +426,23 @@ struct GridwiseGemmMultipleD_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1)); Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
} }
}(); }();
...@@ -436,7 +462,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -436,7 +462,7 @@ struct GridwiseGemmMultipleD_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -456,7 +482,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -456,7 +482,7 @@ struct GridwiseGemmMultipleD_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -471,45 +497,33 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -471,45 +497,33 @@ struct GridwiseGemmMultipleD_Wmma
constexpr auto a_wave_desc = [&]() { constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds) if constexpr(AEnableLds)
{ {
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
return transform_tensor_descriptor( Number<MRepeat>{},
ABlockDesc_{}, I1,
make_tuple(make_freeze_transform(I0), Number<A_KRow>{},
make_pass_through_transform(Number<KWmma>{}), I1,
make_pass_through_transform(Number<MRepeat>{}), Number<A_K1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -525,42 +539,31 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -525,42 +539,31 @@ struct GridwiseGemmMultipleD_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_{}, BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform // Workaround, Freeze transform
return transform_tensor_descriptor( return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
BBlockDesc_{}, Number<NRepeat>{},
make_tuple(make_freeze_transform(I0), I1,
make_pass_through_transform(Number<KWmma>{}), Number<B_KRow>{},
make_pass_through_transform(Number<NRepeat>{}), I1,
make_pass_through_transform(I1), Number<B_K1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -620,9 +623,9 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -620,9 +623,9 @@ struct GridwiseGemmMultipleD_Wmma
else else
{ {
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4), a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5)); a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
} }
}; };
...@@ -635,9 +638,9 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -635,9 +638,9 @@ struct GridwiseGemmMultipleD_Wmma
else else
{ {
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5)); b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
} }
}; };
...@@ -837,7 +840,8 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -837,7 +840,8 @@ struct GridwiseGemmMultipleD_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
} }
else{ else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
} }
}(); }();
...@@ -888,8 +892,9 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -888,8 +892,9 @@ struct GridwiseGemmMultipleD_Wmma
else else
{ {
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
...@@ -902,11 +907,12 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -902,11 +907,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{}, Number<MRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -914,6 +920,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -914,6 +920,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index(0, make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma), m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
...@@ -967,7 +974,8 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -967,7 +974,8 @@ struct GridwiseGemmMultipleD_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize()); b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical // Limitation: NumDim of Src and Dst descriptor should be identical
...@@ -979,11 +987,12 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -979,11 +987,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{}, Number<NRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -991,6 +1000,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -991,6 +1000,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index(0, make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma), n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
......
...@@ -655,7 +655,7 @@ struct GridwiseGemm_Wmma ...@@ -655,7 +655,7 @@ struct GridwiseGemm_Wmma
else else
{ {
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value; constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment