Commit 2d3d7190 authored by aska-0096's avatar aska-0096
Browse files

Enable skipLds feature and fix compilation bugs

parent 5573651e
...@@ -470,7 +470,7 @@ struct BlockwiseGemmWMMA ...@@ -470,7 +470,7 @@ struct BlockwiseGemmWMMA
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
A_K1, A_K1,
TransposeC ? false : true>; false>;
}; };
template <bool EnableLds> template <bool EnableLds>
...@@ -504,7 +504,7 @@ struct BlockwiseGemmWMMA ...@@ -504,7 +504,7 @@ struct BlockwiseGemmWMMA
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
B_K1, B_K1,
TransposeC ? true : false>; false>;
}; };
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_; typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
......
...@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
...@@ -869,13 +874,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -869,13 +874,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
else else
{ {
if(!(arg.a_kz_stride_ == 1 && if(!(arg.a_kz_stride_ == 1))
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) {
index_t LastK =
AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
if(LastK % ABlockTransferSrcScalarPerVector == 0)
{ {
printf("DeviceOp: Vector Access A-k check failure\n"); printf("DeviceOp: Vector Access A-k check failure\n");
return false; return false;
} }
} }
}
// vector memory access of B: could be on N or BK1 dimension // vector memory access of B: could be on N or BK1 dimension
if constexpr(BBlockTransferSrcVectorDim == 1) if constexpr(BBlockTransferSrcVectorDim == 1)
......
...@@ -373,10 +373,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -373,10 +373,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1; constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
B0BlockDesc_{}, B0BlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})), Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
...@@ -430,10 +434,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -430,10 +434,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_LRow = I2;
#else
constexpr auto B_LRow = I1; constexpr auto B_LRow = I1;
#endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
B1BlockDesc_{}, B1BlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)), make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_L1>{})), make_pass_through_transform(Number<B_L1>{})),
......
...@@ -304,10 +304,14 @@ struct GridwiseFpAintBGemm_Wmma ...@@ -304,10 +304,14 @@ struct GridwiseFpAintBGemm_Wmma
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto A_KRow = I2;
#else
constexpr auto A_KRow = I1; constexpr auto A_KRow = I1;
#endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
...@@ -362,10 +366,14 @@ struct GridwiseFpAintBGemm_Wmma ...@@ -362,10 +366,14 @@ struct GridwiseFpAintBGemm_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1; constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_{}, BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
......
...@@ -506,7 +506,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -506,7 +506,7 @@ struct GridwiseGemmMultipleD_Wmma
#endif #endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
...@@ -549,7 +549,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -549,7 +549,7 @@ struct GridwiseGemmMultipleD_Wmma
#endif #endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_{}, BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
......
...@@ -303,7 +303,7 @@ struct GridwiseGemm_Wmma ...@@ -303,7 +303,7 @@ struct GridwiseGemm_Wmma
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
...@@ -365,7 +365,7 @@ struct GridwiseGemm_Wmma ...@@ -365,7 +365,7 @@ struct GridwiseGemm_Wmma
#endif #endif
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_{}, BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment