Commit 579f84c6 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 7e003d31
...@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault, GemmDefault,
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
16, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
1, // M Repeat 2, // M Repeat
1, // N-Repeat 4, // N-Repeat
S<4, 64, 1>, S<4, 64, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true, true,
1, // C shuffle (M Repeat) Per store 1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store 1, // C shuffle (N Repeat) Per store
S<1, 128, 1, 2>, S<1, 64, 1, 4>,
8>; 8>;
// clang-format on // clang-format on
......
...@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
case 4: case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break; break;
default: default:
......
...@@ -129,7 +129,7 @@ using DeviceGemmInstance = ...@@ -129,7 +129,7 @@ using DeviceGemmInstance =
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
8, 8,
1, 1, // be eight?
false, false,
1, // CShuffleMWmmaPerWavePerShuffle 1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle 2, // CShuffleNWmmaPerWavePerShuffle
......
...@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf("Warm up 1 time\n"); printf("Warm up 1 time\n");
#endif #endif
// warm up // warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 100; const int nrepeat = 1;
#if DEBUG_LOG #if DEBUG_LOG
printf("Start running %d times...\n", nrepeat); printf("Start running %d times...\n", nrepeat);
#endif #endif
......
...@@ -27,6 +27,8 @@ template <index_t BlockSize, ...@@ -27,6 +27,8 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false> bool TransposeC = false>
/* Option: Read from LDS, big buffer hold all threads required data /* Option: Read from LDS, big buffer hold all threads required data
* Source * Source
...@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA ...@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
static constexpr bool AEnableLds = NWaves == 1 ? false : true;
static constexpr bool BEnableLds = MWaves == 1 ? false : true;
// Read from Lds, duplicate Twice, Read from VGPR, no duplication. // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1; static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1; static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
......
...@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds = NWaves == 1 ? false : true; static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true; static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// static constexpr auto AEnableLds = true;
// static constexpr auto BEnableLds = true;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory // Describe how data read from Global memory
......
...@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true> ...@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
// preload data into LDS // preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
...@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true> ...@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
}, },
Number<a_block_desc.GetLengths().GetSize()>{}); Number<a_block_desc.GetLengths().GetSize()>{});
#endif #endif
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0); constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf; auto a_block_buf_switch = a_block_buf;
......
...@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
NPerWmma, NPerWmma,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
AEnableLds,
BEnableLds>{};
// Prepare Register for C matrix // Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment