Commit bee4e344 authored by aska-0096's avatar aska-0096
Browse files

(5/5) Attention pass. TODO: debug LDS performance bug

parent fd4ff3a7
...@@ -53,8 +53,8 @@ using DeviceConvFwdInstance = ...@@ -53,8 +53,8 @@ using DeviceConvFwdInstance =
GemmSpec, // GemmSpecialization GemmSpec, // GemmSpecialization
1, // Prefetch stage 1, // Prefetch stage
128, // BlockSize 128, // BlockSize
64, // MPerBlock 64, // MPerBlock
64, // NPerBlock 64, // NPerBlock
64, // KPerBlock 64, // KPerBlock
4, // K1 4, // K1
16, // MPerWMMA 16, // MPerWMMA
......
...@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA ...@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
static_for<0, KPerBlock / WmmaK, 1>{}( static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A // read A
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1, a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0), make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
...@@ -365,58 +365,57 @@ struct BlockwiseGemmWMMA ...@@ -365,58 +365,57 @@ struct BlockwiseGemmWMMA
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}( static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... // k=0,kpack*1, ... read B
// read B b_thread_copy_.Run(
b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1,
b_block_desc_k0_n0_n1_n2_k1, make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0), b_block_buf,
b_block_buf, b_thread_desc_,
b_thread_desc_, make_tuple(I0, n0, I0, I0, I0, I0),
make_tuple(I0, n0, I0, I0, I0, I0), b_thread_buf);
b_thread_buf); // read A
// read A a_thread_copy_.Run(
a_thread_copy_.Run( a_block_desc_k0_m0_m1_m2_k1,
a_block_desc_k0_m0_m1_m2_k1, make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0), a_block_buf,
a_block_buf, a_thread_desc_,
a_thread_desc_, make_tuple(I0, m0, I0, I0, I0, I0),
make_tuple(I0, m0, I0, I0, I0, I0), a_thread_buf);
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) { vector_type<FloatA, WmmaK> a_thread_vec;
b_thread_vec.template AsType<FloatB>()(i) = vector_type<FloatB, WmmaK> b_thread_vec;
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow, static_for<0, WmmaK, 1>{}([&](auto i) {
n0, b_thread_vec.template AsType<FloatB>()(i) =
0, b_thread_buf[Number<b_thread_desc_.CalculateOffset(
(i / B_K1) % B_KRow, make_tuple(i / B_K1 / B_KRow,
0, n0,
i % B_K1))>{}]; 0,
a_thread_vec.template AsType<FloatA>()(i) = (i / B_K1) % B_KRow,
a_thread_buf[Number<a_thread_desc_.CalculateOffset( 0,
make_tuple(i / A_K1 / A_KRow, i % B_K1))>{}];
m0, a_thread_vec.template AsType<FloatA>()(i) =
0, a_thread_buf[Number<a_thread_desc_.CalculateOffset(
(i / A_K1) % A_KRow, make_tuple(i / A_K1 / A_KRow,
0, m0,
i % A_K1))>{}]; 0,
}); (i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
} }
......
...@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = 16;
...@@ -175,13 +176,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -175,13 +176,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
} }
else else
{ {
return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1( return Transform::
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
Number<WmmaK>{}, Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
Number<MRepeat>{}, a_gs_ms_ks_strides_vec),
Number<MWaves>{}, Number<WmmaK>{},
Number<MPerWmma>{}, Number<MRepeat>{},
Number<K1>{}); Number<MWaves>{},
Number<MPerWmma>{},
Number<K1>{});
} }
} }
...@@ -197,14 +200,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -197,14 +200,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
} }
else else
{ {
return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1( return Transform::
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
b0_gs_ls_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
Number<WmmaK>{}, b0_gs_ls_ks_strides_vec),
Number<LRepeat>{}, Number<WmmaK>{},
Number<LWaves>{}, Number<LRepeat>{},
Number<LPerWmma>{}, Number<LWaves>{},
Number<K1>{}); Number<LPerWmma>{},
Number<K1>{});
} }
} }
...@@ -220,14 +224,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -220,14 +224,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
} }
else else
{ {
return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1( return Transform::
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
b1_gs_ns_ls_strides_vec), Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
Number<WmmaK>{}, b1_gs_ns_ls_strides_vec),
Number<NRepeat>{}, Number<WmmaK>{},
Number<NWaves>{}, Number<NRepeat>{},
Number<NPerWmma>{}, Number<NWaves>{},
Number<L1>{}); Number<NPerWmma>{},
Number<L1>{});
} }
} }
...@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
else else
{ {
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) * return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I5); arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
} }
}(); }();
...@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<< "CSpec" << getTensorSpecializationString(CSpec) << ", " << "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << getMaskingSpecializationString(MaskingSpec)
<< ">" << ">"
<< " NumPrefetch: " << " AEnableLds: "
<< AEnableLds << ", "
<< "B0EnableLds: "
<< B0EnableLds << ", "
<< "B1EnableLds: "
<< B1EnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
<< "LoopScheduler: " << "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
...@@ -468,26 +468,25 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -468,26 +468,25 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
has_main_k_block_loop>; // Last Option is W/O has_main_k_block_loop>; // Last Option is W/O
return return launch_and_time_kernel(stream_config,
launch_and_time_kernel(stream_config, kernel,
kernel, dim3(grid_size),
dim3(grid_size), dim3(BlockSize),
dim3(BlockSize), 0,
0, arg.p_a_grid_,
arg.p_a_grid_, arg.p_b_grid_,
arg.p_b_grid_, arg.p_ds_grid_,
arg.p_ds_grid_, arg.p_e_grid_,
arg.p_e_grid_, arg.a_grid_desc,
arg.a_grid_desc, arg.b_grid_desc,
arg.b_grid_desc, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock, arg.a_element_op_,
arg.a_element_op_, arg.b_element_op_,
arg.b_element_op_, arg.cde_element_op_,
arg.cde_element_op_, arg.block_2_ctile_map_);
arg.block_2_ctile_map_);
}; };
if(GridwiseOp::CalculateHasMainKBlockLoop(K)) if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{ {
return launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
......
...@@ -398,21 +398,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -398,21 +398,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
has_main_k_block_loop>; has_main_k_block_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_, arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock, arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
......
...@@ -243,10 +243,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -243,10 +243,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / AK1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, AK1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1)); Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
AK1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * AK1,
Number<K0PerWmma>{} * AK1,
Number<K0PerWmma>{} * AK1,
AK1,
AK1,
AK1,
I1));
} }
}(); }();
...@@ -277,10 +290,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -277,10 +290,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->BK1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / BK1;
// KWmma->LRepeat->LWave->K0PerWmma->KRow->LPerWmma->BK1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1)); Number<LRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
BK1),
make_tuple(Number<LRepeat>{} * Number<K0PerWmma>{} * BK1,
Number<K0PerWmma>{} * BK1,
Number<K0PerWmma>{} * BK1,
BK1,
BK1,
BK1,
I1));
} }
}(); }();
...@@ -310,10 +336,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -310,10 +336,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
constexpr auto LWmmaPerblock = LPerBlock / WmmaL; constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
// LWmma->NRepeat->NWave->NRow->LPerWmma->BL1 Per Thread constexpr auto L0PerWmma = WmmaL / 2 / BL1;
// LWmma->NRepeat->NWave->L0PerWmma->LRow->NPerWmma->BL1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1), make_tuple(Number<LWmmaPerblock>{},
make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1)); Number<NRepeat>{},
I1,
Number<L0PerWmma>{},
I1,
I1,
BL1),
make_tuple(Number<NRepeat>{} * Number<L0PerWmma>{} * BL1,
Number<L0PerWmma>{} * BL1,
Number<L0PerWmma>{} * BL1,
BL1,
BL1,
BL1,
I1));
} }
}(); }();
...@@ -333,7 +372,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -333,7 +372,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -353,7 +392,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -353,7 +392,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -371,7 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -371,7 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
{ {
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -387,44 +426,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -387,44 +426,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(AEnableLds) if constexpr(AEnableLds)
{ {
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
// Workaround, Freeze transform constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
return transform_tensor_descriptor(
ABlockDesc_{}, return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
make_tuple(make_freeze_transform(I0), Number<MRepeat>{},
make_pass_through_transform(Number<KWmma>{}), I1,
make_pass_through_transform(Number<MRepeat>{}), Number<A_KRow>{},
make_pass_through_transform(I1), I1,
make_pass_through_transform(I1), Number<A_K1>{}));
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -439,44 +466,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -439,44 +466,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B0EnableLds) if constexpr(B0EnableLds)
{ {
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
B0BlockDesc_{}, B0BlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})), Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1 // KWmma_LRepeat_LWave_K0PerWmma_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0); constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform // Workaround, Freeze transform
return transform_tensor_descriptor( return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
B0BlockDesc_{}, Number<LRepeat>{},
make_tuple(make_freeze_transform(I0), I1,
make_pass_through_transform(Number<KWmma>{}), Number<B_KRow>{},
make_pass_through_transform(Number<LRepeat>{}), I1,
make_pass_through_transform(I1), Number<B_K1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -489,14 +505,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -489,14 +505,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
{ {
constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
constexpr auto A_LRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
A1BlockDesc_AL0_M_AL1{}, A1BlockDesc_AL0_M_AL1{},
make_tuple(make_pass_through_transform(Number<A_L0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)), make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
make_pass_through_transform(Number<A_L1>{})), make_pass_through_transform(Number<A_L1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
template <typename B1BlockDesc_> template <typename B1BlockDesc_>
...@@ -507,44 +523,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -507,44 +523,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B1EnableLds) if constexpr(B1EnableLds)
{ {
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
constexpr auto B_LRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
B1BlockDesc_{}, B1BlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_L0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_L1>{})), make_pass_through_transform(Number<B_L1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1 constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0); constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I5); constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return transform_tensor_descriptor( return make_naive_tensor_descriptor_packed(make_tuple(Number<LWmma * L0PerWmma>{},
B1BlockDesc_{}, Number<NRepeat>{},
make_tuple(make_freeze_transform(I0), I1,
make_pass_through_transform(Number<LWmma>{}), Number<B_LRow>{},
make_pass_through_transform(Number<NRepeat>{}), I1,
make_pass_through_transform(I1), Number<B_L1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_L1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -610,9 +613,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -610,9 +613,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4), a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5)); a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
} }
}; };
...@@ -625,9 +628,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -625,9 +628,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
b0_grid_desc.GetLength(I4), b0_grid_desc.GetLength(I5),
b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) * b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
b0_grid_desc.GetLength(I5)); b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6));
} }
}; };
...@@ -640,9 +643,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -640,9 +643,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
else else
{ {
return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) * return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
b1_grid_desc.GetLength(I4), b1_grid_desc.GetLength(I5),
b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) * b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
b1_grid_desc.GetLength(I5)); b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6));
} }
}; };
...@@ -884,6 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -884,6 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
...@@ -896,11 +900,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -896,11 +900,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{}, Number<MRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -908,6 +913,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -908,6 +913,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
make_multi_index(0, make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma), m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
...@@ -960,6 +966,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -960,6 +966,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>( auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
b0_block_desc.GetElementSpaceSize()); b0_block_desc.GetElementSpaceSize());
...@@ -972,11 +979,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -972,11 +979,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<LRepeat>{}, Number<LRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
B0BlockTransferSrcScalarPerVector, B0BlockTransferSrcScalarPerVector,
B0ThreadTransferSrcResetCoordinateAfterRun, B0ThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -984,6 +992,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -984,6 +992,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
make_multi_index(0, make_multi_index(0,
0/(LWaves * LPerWmma), 0/(LWaves * LPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
...@@ -1054,7 +1063,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1054,7 +1063,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
} }
else{ else{
return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0); return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -1063,7 +1072,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1063,7 +1072,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
} }
else{ else{
return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0); return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -1072,7 +1081,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1072,7 +1081,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
} }
else{ else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
} }
}(); }();
...@@ -1208,6 +1218,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1208,6 +1218,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
constexpr auto L0PerWmma = WmmaL/2/L1Value;
auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>( auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>(
b1_block_desc.GetElementSpaceSize()); b1_block_desc.GetElementSpaceSize());
...@@ -1220,11 +1231,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1220,11 +1231,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Sequence<Number<LWmmaPerBlock>{}, Sequence<Number<LWmmaPerBlock>{},
Number<NRepeat>{}, Number<NRepeat>{},
I1, I1,
Number<L0PerWmma>{},
I1, I1,
I1, I1,
Number<L1Value>{}>, Number<L1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
B1BlockTransferSrcScalarPerVector, B1BlockTransferSrcScalarPerVector,
B1ThreadTransferSrcResetCoordinateAfterRun, B1ThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -1232,6 +1244,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1232,6 +1244,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
make_multi_index(0, make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma), n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
...@@ -1262,7 +1275,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1262,7 +1275,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
KPack, KPack,
false, false,
B1EnableLds, B1EnableLds,
true>{make_tuple(0, 0, 0, 0, 0)}; true>{make_tuple(0, 0, 0, 0, 0, 0)};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
...@@ -1271,7 +1284,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1271,7 +1284,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
return b0_grid_desc.GetLength(I1); return b0_grid_desc.GetLength(I1);
} }
else{ else{
return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I4); return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5);
} }
}(); }();
......
...@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename MPerWmma, typename MPerWmma,
typename AK1> typename AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1( MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
const AGridDesc_M_K& a_grid_desc_m_k, const AGridDesc_M_K& a_grid_desc_m_k,
const WmmaK&, const WmmaK&,
const MRepeat&, const MRepeat&,
...@@ -194,17 +194,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -194,17 +194,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const MPerWmma&, const MPerWmma&,
const AK1&) const AK1&)
{ {
const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock; const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AKWmma = K / WmmaK{}; const auto AKWmma = K / WmmaK{};
constexpr auto AKRow = WmmaK{} / AK1{}; constexpr auto AKRow = 2;
constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})), make_tuple(make_unmerge_transform(
make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))), make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
// //
...@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename LPerWmma, typename LPerWmma,
typename BK1> typename BK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1( MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k, const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&, const WmmaK&,
const LRepeat&, const LRepeat&,
...@@ -262,17 +264,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -262,17 +264,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const LPerWmma&, const LPerWmma&,
const BK1&) const BK1&)
{ {
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock; const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1); const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{}; const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = WmmaK{} / BK1{}; constexpr auto BKRow = 2;
constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_l_k, b_grid_desc_l_k,
make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})), make_tuple(make_unmerge_transform(
make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))), make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
// //
...@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename NPerWmma, typename NPerWmma,
typename BL1> typename BL1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1( MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l, const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&, const WmmaL&,
const NRepeat&, const NRepeat&,
...@@ -331,17 +335,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -331,17 +335,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const NPerWmma&, const NPerWmma&,
const BL1&) const BL1&)
{ {
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock; const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1); const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{}; const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = WmmaL{} / BL1{}; constexpr auto BLRow = 2;
constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_n_l, b_grid_desc_n_l,
make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})), make_tuple(make_unmerge_transform(
make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))), make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
// //
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment