Commit 268c497c authored by Jing Zhang

fixed lds_enabled

parent 26e8ba9f
@@ -40,7 +40,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     64,   // MPerBlock
     128,  // NPerBlock
     64,   // KPerBlock
-    4,    // K1
+    2,    // K1
     16,   // MPerWmma
     16,   // NPerWmma
     2,    // M-Repeat // M-PerWmma / M-Repeat = M-Wave
@@ -49,15 +49,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
-    1,
-    1,
+    2,
+    2,
     true,
     S<4, 32, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
-    1,
-    1,
+    2,
+    2,
     true,
     1,    // C shuffle (M Repeat) Per store
     1,    // C shuffle (N Repeat) Per store
...
@@ -70,9 +70,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
-    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
-    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -316,7 +313,7 @@ struct BlockwiseGemmWMMA
             // read A
             a_thread_copy_.Run(
                 a_block_desc_k0_m0_m1_m2_k1,
-                make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
+                make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                 a_block_buf,
                 a_thread_desc_,
                 make_tuple(I0, m0, I0, I0, I0, I0),
@@ -327,7 +324,7 @@ struct BlockwiseGemmWMMA
             b_thread_copy_.Run(
                 b_block_desc_k0_n0_n1_n2_k1,
                 make_tuple(
-                    Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
+                    Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                 b_block_buf,
                 b_thread_desc_,
                 make_tuple(I0, n0, I0, I0, I0, I0),
@@ -373,7 +370,7 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
@@ -381,7 +378,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
...
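Note on the A_KRow_/B_KRow_ removal above: the per-iteration K0 source offset in the thread copies now divides by the fixed A_KRow / B_KRow (= 2) instead of the removed LDS-dependent A_KRow_ / B_KRow_ (formerly AEnableLds ? 1 : 2). A minimal standalone sketch of that offset arithmetic, not CK code, with KPack = 16 and A_K1 = 2 assumed for illustration:

    #include <cstdio>

    int main()
    {
        constexpr int KPack  = 16; // assumed per-thread K pack
        constexpr int A_K1   = 2;  // assumed A_K1 of the block descriptor
        constexpr int A_KRow = 2;  // fixed constant after this commit

        // Previously the divisor was A_KRow_ = AEnableLds ? 1 : 2, so the K0
        // step depended on whether the A tile went through LDS; now it is fixed.
        for(int k = 0; k < 4; ++k)
        {
            const int k0 = k * KPack / A_K1 / A_KRow; // 0, 4, 8, 12
            std::printf("k=%d -> K0 offset %d\n", k, k0);
        }
        return 0;
    }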
@@ -97,7 +97,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto AEnableLds_manu = false;
     static constexpr auto BEnableLds_manu = false;
-    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
+    static constexpr auto AEnableLds = false; //AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto matrix_padder =
...
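The A-side LDS switch is now hard-coded off, regardless of the auto heuristic, the manual override, or the prefetch count; the B side still composes the three inputs. A minimal sketch of the before/after flag logic, not CK code; AEnableLds_auto's real value is layout-dependent, so the value below is just an assumption:

    int main()
    {
        constexpr bool AEnableLds_auto = true; // assumed
        constexpr bool AEnableLds_manu = false;
        constexpr int  NumPrefetch     = 1;    // assumed

        // Before: any of the three inputs could turn the LDS path on for A.
        constexpr bool before = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);

        // After: forced off, so the A tile bypasses LDS.
        constexpr bool after = false;

        static_assert(before && !after, "illustration only");
        return 0;
    }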
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
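In the two LDS-bypass ("else") branches above, the K decomposition can be sanity-checked with plain arithmetic. A minimal sketch, not CK code, taking KPerBlock = 64 and K1 = 2 from the example instance at the top of this commit and assuming WmmaK = 16:

    int main()
    {
        constexpr int KPerBlock = 64; // from the example instance
        constexpr int K1        = 2;  // from the example instance (new value)
        constexpr int WmmaK     = 16; // assumed WMMA K extent
        constexpr int A_KRow    = 2;  // constant introduced here

        constexpr int KWmmaPerblock = KPerBlock / WmmaK;   // 4
        constexpr int K0PerWmma     = WmmaK / A_KRow / K1; // 16 / 2 / 2 = 4

        // The factors recombine to the per-block K extent.
        static_assert(KWmmaPerblock * K0PerWmma * A_KRow * K1 == KPerBlock,
                      "KWmma * K0PerWmma * KRow * K1 must cover KPerBlock");
        return 0;
    }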
@@ -292,7 +295,7 @@ struct GridwiseGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0   = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1   = ABlockDesc_{}.GetLength(I2);
-        constexpr auto A_KRow = I1;
+        constexpr auto A_KRow = I2;
         return transform_tensor_descriptor(
             ABlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -307,7 +310,6 @@ struct GridwiseGemm_Wmma
         // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
         constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
         constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
-        constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
         constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
         // Err: merge transform cause non-constexpr issue
@@ -332,7 +334,7 @@ struct GridwiseGemm_Wmma
         return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                Number<MRepeat>{},
                                                                I1,
-                                                               Number<A_KRow>{},
+                                                               I1,
                                                                I1,
                                                                Number<A_K1>{}));
     }
@@ -350,7 +352,7 @@ struct GridwiseGemm_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0   = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1   = BBlockDesc_{}.GetLength(I2);
-        constexpr auto B_KRow = I1;
+        constexpr auto B_KRow = I2;
         return transform_tensor_descriptor(
             BBlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -365,14 +367,13 @@ struct GridwiseGemm_Wmma
         // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
         constexpr auto KWmma     = BBlockDesc_{}.GetLength(I0);
         constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
-        constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
         constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);
         // Workaround, Freeze transform
         return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                Number<NRepeat>{},
                                                                I1,
-                                                               Number<B_KRow>{},
+                                                               I1,
                                                                I1,
                                                                Number<B_K1>{}));
     }
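The "Freeze transform" workaround above now pins the KRow axis of the thread-side descriptors to length 1 (it used to carry Number<A_KRow> / Number<B_KRow>), while the block-side descriptors unmerge K0 with A_KRow/B_KRow = 2; the row split appears to be handled through the read offsets shown earlier (k * KPack / B_K1 / B_KRow) rather than through this descriptor. A small sketch of the resulting B-side shape, not CK code, with KWmma = 4, K0PerWmma = 4, NRepeat = 4 and B_K1 = 2 assumed for illustration:

    int main()
    {
        constexpr int KWmma     = 4; // assumed
        constexpr int K0PerWmma = 4; // assumed
        constexpr int NRepeat   = 4; // assumed
        constexpr int B_K1      = 2; // assumed

        // Frozen thread-side lengths after this change:
        //   (KWmma * K0PerWmma, NRepeat, 1, 1, 1, B_K1)
        // The 4th entry was Number<B_KRow> (= 2) before; it is now 1.
        constexpr int lengths[6] = {KWmma * K0PerWmma, NRepeat, 1, 1, 1, B_K1};

        static_assert(lengths[0] == 16 && lengths[3] == 1 && lengths[5] == 2,
                      "illustrative shape only");
        return 0;
    }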
@@ -781,8 +782,6 @@ struct GridwiseGemm_Wmma
         // GEMM
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-        static_assert(KPerBlock % KPack == 0, "");
         auto blockwise_gemm =
             BlockwiseGemmWMMA<BlockSize,
                               ADataType,
...
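For reference, the KPack used by the blockwise GEMM above can be worked out by hand. A sketch, not CK code, assuming math::integer_least_multiple(a, b) returns the smallest multiple of b that is not less than a, with K1 = 2 from the example instance and an assumed WmmaK = 16:

    int main()
    {
        constexpr int K1    = 2;  // from the example instance
        constexpr int WmmaK = 16; // assumed WMMA K extent

        // Smallest multiple of WmmaK that is >= K1 (assumed semantics).
        constexpr int KPack = ((K1 + WmmaK - 1) / WmmaK) * WmmaK; // = 16

        // The dropped static_assert(KPerBlock % KPack == 0) would still have
        // held for the example tile: 64 % 16 == 0.
        constexpr int KPerBlock = 64;
        static_assert(KPack == 16 && KPerBlock % KPack == 0, "illustration only");
        return 0;
    }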