Commit 77a04d6a authored by Jing Zhang

fixed register loads

parent 7d700bc0
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
+    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
+    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -192,8 +195,8 @@ struct BlockwiseGemmWMMA
                       NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
-        static_assert(AEnableLds == true, "only support EnableLds");
-        static_assert(BEnableLds == true, "only support EnableLds");
+        // static_assert(AEnableLds == true, "only support EnableLds");
+        // static_assert(BEnableLds == true, "only support EnableLds");
     }
     // transposed WMMA output C' = B' * A'
@@ -316,7 +319,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,7 +329,8 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                    make_tuple(
+                        Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
@@ -372,7 +376,7 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
@@ -380,7 +384,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
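For reference, the new source offsets divide the K index by A_KRow_/B_KRow_, which this commit sets to 1 when the operand is staged through LDS and 2 when it is read directly into registers. Below is a minimal standalone sketch of that arithmetic only; the KPack and A_K1 values are assumed for illustration and are not taken from this commit.

    // Illustrative sketch only: how the per-thread K source offset changes when
    // an operand skips LDS. KPack and A_K1 are assumed example values.
    #include <cstdio>

    int main()
    {
        constexpr int KPack = 16; // assumed for illustration
        constexpr int A_K1  = 8;  // assumed for illustration

        constexpr bool AEnableLds = false;          // operand read straight into registers
        constexpr int A_KRow_ = AEnableLds ? 1 : 2; // selection introduced by this commit

        for(int k = 0; k < 3; ++k)
        {
            const int old_offset = k * KPack / A_K1;           // offset before this commit
            const int new_offset = k * KPack / A_K1 / A_KRow_; // offset after this commit
            std::printf("k=%d old=%d new=%d\n", k, old_offset, new_offset);
        }
        return 0;
    }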
@@ -442,13 +446,7 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
-    template <bool EnableLds>
-    struct AThreadCopySelector;
-    template <>
-    struct AThreadCopySelector<true>
-    {
-        using type =
+    using AThreadCopyType =
         ThreadwiseTensorSliceTransfer_v4<FloatA,
                                          FloatA,
                                          decltype(a_block_desc_k0_m0_m1_m2_k1),
@@ -458,15 +456,8 @@ struct BlockwiseGemmWMMA
                                          5,
                                          A_K1,
                                          A_K1>;
-    };
-    template <bool EnableLds>
-    struct BThreadCopySelector;
-    template <>
-    struct BThreadCopySelector<true>
-    {
-        using type =
+    using BThreadCopyType =
         ThreadwiseTensorSliceTransfer_v4<FloatB,
                                          FloatB,
                                          decltype(b_block_desc_k0_n0_n1_n2_k1),
@@ -476,10 +467,9 @@ struct BlockwiseGemmWMMA
                                          5,
                                          B_K1,
                                          B_K1>;
-    };
-    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
-    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+    AThreadCopyType a_thread_copy_;
+    BThreadCopyType b_thread_copy_;
 };
 #else
 template <index_t BlockSize,
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
...
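For context, the effective LDS flags are the OR of the automatic heuristic, the manual override, and the prefetch condition, as shown in the surrounding lines. A minimal sketch of that resolution follows; the AEnableLds_auto and NumPrefetch values are assumed for illustration only.

    // Illustrative sketch only: resolution of the effective LDS flag.
    // AEnableLds_auto and NumPrefetch are assumed example values.
    #include <cstdio>

    int main()
    {
        constexpr bool AEnableLds_auto = false; // assumed: heuristic decides LDS is not required
        constexpr bool AEnableLds_manu = false; // value after this commit: no manual override
        constexpr int  NumPrefetch     = 1;     // assumed: single-stage prefetch

        constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
        std::printf("AEnableLds = %d\n", AEnableLds); // prints 0: the A operand bypasses LDS
        return 0;
    }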
@@ -333,8 +333,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
-        // constexpr auto A_KRow = I1;
-        constexpr auto A_KRow = I2;
+        constexpr auto A_KRow = I1;
         return transform_tensor_descriptor(
             ABlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -374,8 +373,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
         constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
-        // constexpr auto B_KRow = I1;
-        constexpr auto B_KRow = I2;
+        constexpr auto B_KRow = I1;
         return transform_tensor_descriptor(
             B0BlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -412,8 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
     {
         constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
         constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
-        // constexpr auto A_LRow = I1;
-        constexpr auto A_LRow = I2;
+        constexpr auto A_LRow = I1;
         return transform_tensor_descriptor(
             A1BlockDesc_AL0_M_AL1{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
@@ -433,8 +430,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
         constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
         constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
-        // constexpr auto B_LRow = I1;
-        constexpr auto B_LRow = I2;
+        constexpr auto B_LRow = I1;
         return transform_tensor_descriptor(
             B1BlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
@@ -1183,7 +1179,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                                      MRepeat,
                                      NRepeat,
                                      KPack,
-                                     true,
+                                     false,
                                      B1EnableLds,
                                      true>{make_tuple(0, 0, 0, 0, 0, 0)};
@@ -1346,7 +1342,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             block_sync_lds();
-            //blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+            blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
             block_sync_lds();
@@ -1369,7 +1365,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                 block_sync_lds();
-                //blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+                blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
             }
         } // end gemm1
...