Commit 77a04d6a authored by Jing Zhang

fixed register loads

parent 7d700bc0
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
 static constexpr index_t A_KRow = 2;
 static constexpr index_t B_KRow = 2;
+static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
+static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
 static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
 static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
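For orientation, a minimal standalone sketch of the new EnableLds-dependent row factor introduced above; KRowFactor is a made-up name for this sketch only, not the library's code:

// Standalone sketch: the two possible values of the EnableLds-dependent divisor.
template <bool EnableLds>
struct KRowFactor
{
    static constexpr int value = EnableLds ? 1 : 2;
};

static_assert(KRowFactor<true>::value == 1, "LDS path: offset divisor stays 1");
static_assert(KRowFactor<false>::value == 2, "direct-register path: offset divisor is 2");

int main() { return 0; }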
@@ -192,8 +195,8 @@ struct BlockwiseGemmWMMA
 NPerBlock % (NPerWMMA * NRepeat) == 0,
 "wrong!");
-static_assert(AEnableLds == true, "only support EnableLds");
-static_assert(BEnableLds == true, "only support EnableLds");
+// static_assert(AEnableLds == true, "only support EnableLds");
+// static_assert(BEnableLds == true, "only support EnableLds");
 }
 // transposed WMMA output C' = B' * A'
@@ -316,7 +319,7 @@ struct BlockwiseGemmWMMA
 // read A
 a_thread_copy_.Run(
 a_block_desc_k0_m0_m1_m2_k1,
-make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
 a_block_buf,
 a_thread_desc_,
 make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,7 +329,8 @@ struct BlockwiseGemmWMMA
 // read B
 b_thread_copy_.Run(
 b_block_desc_k0_n0_n1_n2_k1,
-make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+make_tuple(
+Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
 b_block_buf,
 b_thread_desc_,
 make_tuple(I0, n0, I0, I0, I0, I0),
@@ -372,7 +376,7 @@ struct BlockwiseGemmWMMA
 // read B
 b_thread_copy_.Run(
 b_block_desc_k0_n0_n1_n2_k1,
-make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
 b_block_buf,
 b_thread_desc_,
 make_tuple(I0, n0, I0, I0, I0, I0),
@@ -380,7 +384,7 @@ struct BlockwiseGemmWMMA
 // read A
 a_thread_copy_.Run(
 a_block_desc_k0_m0_m1_m2_k1,
-make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
 a_block_buf,
 a_thread_desc_,
 make_tuple(I0, m0, I0, I0, I0, I0),
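A quick worked example of the corrected source K0 offset used in the four Run() calls above; KPack = 16 and A_K1 = 8 are assumed values chosen only to make the arithmetic visible, not taken from a real tuning config:

#include <cstdio>

int main()
{
    constexpr int KPack   = 16; // assumed example value
    constexpr int A_K1    = 8;  // assumed example value
    constexpr int A_KRow_ = 2;  // AEnableLds == false in this sketch

    for(int k = 0; k < 4; ++k)
    {
        const int old_k0 = k * KPack / A_K1;           // offset before this commit
        const int new_k0 = k * KPack / A_K1 / A_KRow_; // offset after this commit
        std::printf("k=%d  old K0 offset=%d  new K0 offset=%d\n", k, old_k0, new_k0);
    }
    return 0;
}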
@@ -442,44 +446,30 @@ struct BlockwiseGemmWMMA
 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
-template <bool EnableLds>
-struct AThreadCopySelector;
-template <>
-struct AThreadCopySelector<true>
-{
-using type =
-ThreadwiseTensorSliceTransfer_v4<FloatA,
-FloatA,
-decltype(a_block_desc_k0_m0_m1_m2_k1),
-decltype(a_thread_desc_),
-Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
-Sequence<0, 1, 2, 3, 4, 5>,
-5,
-A_K1,
-A_K1>;
-};
-template <bool EnableLds>
-struct BThreadCopySelector;
-template <>
-struct BThreadCopySelector<true>
-{
-using type =
-ThreadwiseTensorSliceTransfer_v4<FloatB,
-FloatB,
-decltype(b_block_desc_k0_n0_n1_n2_k1),
-decltype(b_thread_desc_),
-Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
-Sequence<0, 1, 2, 3, 4, 5>,
-5,
-B_K1,
-B_K1>;
-};
-typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
-typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+using AThreadCopyType =
+ThreadwiseTensorSliceTransfer_v4<FloatA,
+FloatA,
+decltype(a_block_desc_k0_m0_m1_m2_k1),
+decltype(a_thread_desc_),
+Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+Sequence<0, 1, 2, 3, 4, 5>,
+5,
+A_K1,
+A_K1>;
+using BThreadCopyType =
+ThreadwiseTensorSliceTransfer_v4<FloatB,
+FloatB,
+decltype(b_block_desc_k0_n0_n1_n2_k1),
+decltype(b_thread_desc_),
+Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+Sequence<0, 1, 2, 3, 4, 5>,
+5,
+B_K1,
+B_K1>;
+AThreadCopyType a_thread_copy_;
+BThreadCopyType b_thread_copy_;
 };
 #else
 template <index_t BlockSize,
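The hunk above replaces the EnableLds-specialized AThreadCopySelector/BThreadCopySelector wrappers with plain type aliases. A minimal, generic illustration of that refactor pattern (std::vector stands in for the transfer type; none of this is the library's actual code):

#include <vector>

// Old shape: a selector template specialized per EnableLds value.
template <bool EnableLds>
struct CopySelector;

template <>
struct CopySelector<true>
{
    using type = std::vector<int>; // stand-in for the LDS-path transfer type
};

// New shape: a single alias, once both paths share one transfer template,
// as the diff does with ThreadwiseTensorSliceTransfer_v4.
using CopyType = std::vector<int>;

int main()
{
    CopySelector<true>::type a; // old spelling
    CopyType b;                 // new spelling
    (void)a;
    (void)b;
    return 0;
}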
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
 (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
 // If true, LDS is used unconditionally
-static constexpr auto AEnableLds_manu = true;
-static constexpr auto BEnableLds_manu = true;
+static constexpr auto AEnableLds_manu = false;
+static constexpr auto BEnableLds_manu = false;
 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
 static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
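A small sketch of how the enable flag resolves after this change; all three inputs are assumed values for illustration only:

int main()
{
    // Assumed inputs for this sketch.
    constexpr bool AEnableLds_auto = false;
    constexpr bool AEnableLds_manu = false; // the manual override the diff turns off
    constexpr int  NumPrefetch     = 1;

    constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static_assert(!AEnableLds, "with these assumed inputs the A operand now bypasses LDS");
    return 0;
}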
@@ -331,10 +331,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 if constexpr(AEnableLds)
 {
 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
-constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
-constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
-// constexpr auto A_KRow = I1;
-constexpr auto A_KRow = I2;
+constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+constexpr auto A_KRow = I1;
 return transform_tensor_descriptor(
 ABlockDesc_{},
 make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
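The hunks in this file switch the row extent fed to make_unmerge_transform from I2 to I1. A rough standalone illustration of what unmerging a K dimension into (k0, krow) coordinates means in plain index arithmetic; this shows only the indexing idea, not the library's descriptor machinery:

#include <cstdio>

int main()
{
    constexpr int K0   = 4; // assumed example extent
    constexpr int KRow = 1; // the diff changes this extent from 2 to 1

    // Unmerge: one linear index of length K0 * KRow -> (k0, krow) coordinates.
    // With KRow == 1 the second coordinate is always 0, so the split is a no-op.
    for(int linear = 0; linear < K0 * KRow; ++linear)
    {
        const int k0   = linear / KRow;
        const int krow = linear % KRow;
        std::printf("linear=%d -> (k0=%d, krow=%d)\n", linear, k0, krow);
    }
    return 0;
}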
@@ -372,10 +371,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 if constexpr(B0EnableLds)
 {
 // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
-constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
-constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
-// constexpr auto B_KRow = I1;
-constexpr auto B_KRow = I2;
+constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
+constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+constexpr auto B_KRow = I1;
 return transform_tensor_descriptor(
 B0BlockDesc_{},
 make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -412,8 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 {
 constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
 constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
-// constexpr auto A_LRow = I1;
-constexpr auto A_LRow = I2;
+constexpr auto A_LRow = I1;
 return transform_tensor_descriptor(
 A1BlockDesc_AL0_M_AL1{},
 make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
@@ -431,10 +428,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 if constexpr(B1EnableLds)
 {
 // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
-constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
-constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
-// constexpr auto B_LRow = I1;
-constexpr auto B_LRow = I2;
+constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
+constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+constexpr auto B_LRow = I1;
 return transform_tensor_descriptor(
 B1BlockDesc_{},
 make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
@@ -1183,7 +1179,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 MRepeat,
 NRepeat,
 KPack,
-true,
+false,
 B1EnableLds,
 true>{make_tuple(0, 0, 0, 0, 0, 0)};
@@ -1346,7 +1342,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 block_sync_lds();
-//blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
 block_sync_lds();
@@ -1369,7 +1365,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
 block_sync_lds();
-//blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
 }
 } // end gemm1