Commit 866e309f authored by aska-0096's avatar aska-0096
Browse files

hipgraph on=2.17905us

parent 0dbe5370
...@@ -148,10 +148,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -148,10 +148,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 256, 256, 32, 128, 256,
16, 16, 16, 16,
32, 32, 32, 32,
1, 2, 1, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
...@@ -320,7 +320,7 @@ int main(int argc, char* argv[]) ...@@ -320,7 +320,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 500});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -20,6 +20,8 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -20,6 +20,8 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
if(stream_config.time_kernel_) if(stream_config.time_kernel_)
{ {
#if 0
printf("HipGraph OFF\n");
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
...@@ -70,6 +72,53 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -70,6 +72,53 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventDestroy(stop)); hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
#elif 1
printf("HipGraph ON\n");
hipGraph_t graph_;
hipStream_t stream_;
HIP_CHECK_ERROR(hipStreamCreate(&stream_));
StreamConfig sc{stream_};
HIP_CHECK_ERROR(hipStreamBeginCapture(sc.stream_id_, hipStreamCaptureModeGlobal));
for(int i_r = 0; i_r < stream_config.nrepeat_; i_r++)
{
kernel<<<grid_dim, block_dim, lds_byte, sc.stream_id_>>>(args...);
}
HIP_CHECK_ERROR(hipStreamEndCapture(sc.stream_id_, &graph_));
hipGraphExec_t instance_;
HIP_CHECK_ERROR(hipGraphInstantiate(&instance_, graph_, nullptr, nullptr, 0));
hipEvent_t start_, stop_;
HIP_CHECK_ERROR(hipEventCreate(&start_));
HIP_CHECK_ERROR(hipEventCreate(&stop_));
// warm-up
for(int i_r = 0; i_r < stream_config.cold_niters_; i_r++)
{
kernel<<<grid_dim, block_dim, lds_byte, sc.stream_id_>>>(args...);
}
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_, sc.stream_id_));
HIP_CHECK_ERROR(hipGraphLaunch(instance_, sc.stream_id_));
HIP_CHECK_ERROR(hipEventRecord(stop_, sc.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop_));
HIP_CHECK_ERROR(hipGetLastError());
HIP_CHECK_ERROR(hipGraphDestroy(graph_));
float total_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start_, stop_));
return total_time / stream_config.nrepeat_;
#endif
} }
else else
{ {
......
...@@ -39,6 +39,7 @@ __global__ void ...@@ -39,6 +39,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
{ {
return;
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment