Commit 268c497c authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed lds_enabled

parent 26e8ba9f
......@@ -40,7 +40,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
4, // K1
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
......@@ -49,15 +49,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
S<1, 0, 2>,
S<1, 0, 2>,
2,
1,
1,
2,
2,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
1,
1,
2,
2,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
......
......@@ -70,9 +70,6 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
......@@ -316,7 +313,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
......@@ -327,7 +324,7 @@ struct BlockwiseGemmWMMA
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(
Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
......@@ -373,7 +370,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
......@@ -381,7 +378,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
......
......@@ -97,7 +97,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto AEnableLds = false; //AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto matrix_padder =
......
......@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr auto A_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr auto B_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -292,7 +295,7 @@ struct GridwiseGemm_Wmma
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
constexpr auto A_KRow = I2;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
......@@ -307,7 +310,6 @@ struct GridwiseGemm_Wmma
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Err: merge transform cause non-constexpr issue
......@@ -332,7 +334,7 @@ struct GridwiseGemm_Wmma
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
I1,
Number<A_K1>{}));
}
......@@ -350,7 +352,7 @@ struct GridwiseGemm_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
constexpr auto B_KRow = I2;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
......@@ -365,14 +367,13 @@ struct GridwiseGemm_Wmma
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
I1,
Number<B_K1>{}));
}
......@@ -781,8 +782,6 @@ struct GridwiseGemm_Wmma
// GEMM
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
static_assert(KPerBlock % KPack == 0, "");
auto blockwise_gemm =
BlockwiseGemmWMMA<BlockSize,
ADataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment