Commit dc8309db authored by aska-0096's avatar aska-0096
Browse files

Skip A_Lds sanity pass, Skip B_Lds scratch occurred

parent a4694341
...@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
8, // M Repeat 8, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
1, // N-Repeat 1, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 64, 1>, S<4, 64, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -51,16 +51,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -51,16 +51,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, 8,
8, 8,
true, true,
S<4, 16, 1>, S<4, 64, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
1, // C shuffle (M Repeat) Per store 4, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store 1, // C shuffle (N Repeat) Per store
S<1, 16, 1, 16>, S<1, 32, 1, 8>,
8>; 8>;
// clang-format on // clang-format on
......
...@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
// warm up // warm up
// kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 1; const int nrepeat = 100;
#if DEBUG_LOG #if DEBUG_LOG
printf("Start running %d times...\n", nrepeat); printf("Start running %d times...\n", nrepeat);
#endif #endif
......
...@@ -298,58 +298,123 @@ struct BlockwiseGemmWMMA ...@@ -298,58 +298,123 @@ struct BlockwiseGemmWMMA
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... // basic intrinsic to determine loopover direction
static_for<0, MRepeat, 1>{}([&](auto m0) { if constexpr(MRepeat < NRepeat)
// read A {
a_thread_copy_.Run( static_for<0, KPerBlock / WmmaK, 1>{}(
a_block_desc_k0_m0_m1_m2_k1, [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
make_tuple( static_for<0, MRepeat, 1>{}([&](auto m0) {
Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0), // read A
a_block_buf, a_thread_copy_.Run(
a_thread_desc_, a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, I0, I0, I0), make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
a_thread_buf); m0,
I0,
static_for<0, NRepeat, 1>{}([&](auto n0) { I0,
// read B I0),
b_thread_copy_.Run( a_block_buf,
b_block_desc_k0_n0_n1_n2_k1, a_thread_desc_,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{}, make_tuple(I0, m0, I0, I0, I0),
n0, a_thread_buf);
I0,
I0, static_for<0, NRepeat, 1>{}([&](auto n0) {
I0), // read B
b_block_buf, b_thread_copy_.Run(
b_thread_desc_, b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, I0, I0, I0), make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
b_thread_buf); n0,
I0,
vector_type<FloatA, WmmaK> a_thread_vec; I0,
vector_type<FloatB, WmmaK> b_thread_vec; I0),
b_block_buf,
static_for<0, WmmaK, 1>{}([&](auto i) { b_thread_desc_,
a_thread_vec.template AsType<FloatA>()(i) = make_tuple(I0, n0, I0, I0, I0),
a_thread_buf[Number<a_thread_desc_.CalculateOffset( b_thread_buf);
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) = vector_type<FloatA, WmmaK> a_thread_vec;
b_thread_buf[Number<b_thread_desc_.CalculateOffset( vector_type<FloatB, WmmaK> b_thread_vec;
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); }
else
}); {
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
n0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
m0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
} }
protected: protected:
......
...@@ -89,8 +89,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -89,8 +89,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds = NWaves == 1 ? false : true; static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true; static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// Unconditional enable double side LDS if uncommented following
// Force enable LDS if uncommented following
// AEnableLds = true; // AEnableLds = true;
// BEnableLds = true; // BEnableLds = true;
...@@ -223,53 +222,53 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -223,53 +222,53 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_Wmma< using GridwiseGemm =
BlockSize, GridwiseGemm_Wmma<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc, AGridDesc,
BGridDesc, BGridDesc,
CGridDesc_M_N, CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerWmma, MPerWmma,
NPerWmma, NPerWmma,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds, AEnableLds,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds, BEnableLds,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle, CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle, CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
NumPrefetch, NumPrefetch,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -572,7 +571,11 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -572,7 +571,11 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
<< MRepeat << ", " << MRepeat << ", "
<< NRepeat << NRepeat
<< ">" << ">"
<< " NumPrefetch: " << " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
<< "LoopScheduler: " << "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment