Commit 579f84c6 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 7e003d31
......@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault,
256, // BlockSize
128, // MPerBlock
16, // NPerBlock
128, // NPerBlock
32, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M Repeat
1, // N-Repeat
2, // M Repeat
4, // N-Repeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 128, 1, 2>,
S<1, 64, 1, 4>,
8>;
// clang-format on
......
......@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
default:
......
......@@ -129,7 +129,7 @@ using DeviceGemmInstance =
S<0, 2, 1>,
1,
8,
1,
1, // be eight?
false,
1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle
......
......@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf("Warm up 1 time\n");
#endif
// warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
// kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 100;
const int nrepeat = 1;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
......
......@@ -27,6 +27,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
......@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
static constexpr bool AEnableLds = NWaves == 1 ? false : true;
static constexpr bool BEnableLds = MWaves == 1 ? false : true;
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
......
......@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// static constexpr auto AEnableLds = true;
// static constexpr auto BEnableLds = true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
......
......@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
......@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
......
......@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
AEnableLds,
BEnableLds>{};
// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment