Commit 866e309f authored by aska-0096
Browse files

hipgraph on=2.17905us

parent 0dbe5370
......@@ -148,10 +148,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 256, 256,
32, 128, 256,
16, 16,
32, 32,
1, 2,
1, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
......@@ -320,7 +320,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 500});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......
......@@ -20,6 +20,8 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
#if 0
printf("HipGraph OFF\n");
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
......@@ -70,6 +72,53 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat;
#elif 1
// HIP-graph timing path: record nrepeat_ back-to-back launches of `kernel` into one
// graph, replay that graph once between two events, and return the elapsed event time
// (milliseconds, per hipEventElapsedTime) divided by nrepeat_.
printf("HipGraph ON\n");
hipGraph_t graph_;
hipStream_t stream_;
// Dedicated stream used for the capture, the warm-up launches and the timed replay.
HIP_CHECK_ERROR(hipStreamCreate(&stream_));
StreamConfig sc{stream_};
// Launches issued on sc.stream_id_ between Begin/EndCapture are recorded into graph_.
HIP_CHECK_ERROR(hipStreamBeginCapture(sc.stream_id_, hipStreamCaptureModeGlobal));
for(int i_r = 0; i_r < stream_config.nrepeat_; i_r++)
{
    kernel<<<grid_dim, block_dim, lds_byte, sc.stream_id_>>>(args...);
}
HIP_CHECK_ERROR(hipStreamEndCapture(sc.stream_id_, &graph_));
hipGraphExec_t instance_;
HIP_CHECK_ERROR(hipGraphInstantiate(&instance_, graph_, nullptr, nullptr, 0));
hipEvent_t start_, stop_;
HIP_CHECK_ERROR(hipEventCreate(&start_));
HIP_CHECK_ERROR(hipEventCreate(&stop_));
// warm-up: direct launches on the same stream, outside the timed region
for(int i_r = 0; i_r < stream_config.cold_niters_; i_r++)
{
    kernel<<<grid_dim, block_dim, lds_byte, sc.stream_id_>>>(args...);
}
// Drain the warm-up work so it is not included between start_ and stop_.
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_, sc.stream_id_));
HIP_CHECK_ERROR(hipGraphLaunch(instance_, sc.stream_id_));
HIP_CHECK_ERROR(hipEventRecord(stop_, sc.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop_));
HIP_CHECK_ERROR(hipGetLastError());
float total_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start_, stop_));
// Release every handle created in this branch (executable graph, graph, events,
// stream) so repeated calls do not accumulate runtime resources.
HIP_CHECK_ERROR(hipGraphExecDestroy(instance_));
HIP_CHECK_ERROR(hipGraphDestroy(graph_));
HIP_CHECK_ERROR(hipEventDestroy(start_));
HIP_CHECK_ERROR(hipEventDestroy(stop_));
HIP_CHECK_ERROR(hipStreamDestroy(stream_));
// NOTE(review): nrepeat_ == 0 yields an empty graph and a 0/0 result here — confirm
// callers always pass a positive repeat count.
return total_time / stream_config.nrepeat_;
#endif
}
else
{
......
......@@ -39,6 +39,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
{
return;
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment