"...composable_kernel.git" did not exist on "5d37d7bff4e631c3b94112c31a52f209ca39dfe2"
Commit 5ef0843b authored by Qianfeng Zhang
Browse files

Update the forward examples under example/52_xxx to use hipGraph...

Update the forward examples under example/52_xxx to use hipGraph capturing and replaying for verification
parent e29a9111
...@@ -244,6 +244,18 @@ int run(int argc, char* argv[]) ...@@ -244,6 +244,18 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
// run for storing z tensor // run for storing z tensor
argument = argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
...@@ -277,9 +289,19 @@ int run(int argc, char* argv[]) ...@@ -277,9 +289,19 @@ int run(int argc, char* argv[])
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
c_device_buf.SetZero(); HIP_CHECK_ERROR(hipMemsetAsync(
lse_device_buf.SetZero(); c_device_buf.GetDeviceBuffer(), 0, c_device_buf.GetBufferSize(), stream));
invoker.Run(argument, StreamConfig{nullptr, false}); HIP_CHECK_ERROR(hipMemsetAsync(
lse_device_buf.GetDeviceBuffer(), 0, lse_device_buf.GetBufferSize(), stream));
invoker.Run(argument, StreamConfig{stream, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
......
...@@ -303,6 +303,16 @@ int run(int argc, char* argv[]) ...@@ -303,6 +303,16 @@ int run(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
argument = argument =
gemm.MakeArgument(p_a, gemm.MakeArgument(p_a,
p_b0, p_b0,
...@@ -326,8 +336,17 @@ int run(int argc, char* argv[]) ...@@ -326,8 +336,17 @@ int run(int argc, char* argv[])
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
const int& G0 = g0_g1_m_n_k_o[i][0]; const int& G0 = g0_g1_m_n_k_o[i][0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment