Commit 5ef0843b authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Update the forward examples under example/52_xxx to use hipGraph...

Update the forward examples under example/52_xxx to use hipGraph capture and replay for verification
parent e29a9111
......@@ -244,6 +244,18 @@ int run(int argc, char* argv[])
if(do_verification)
{
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
// run for storing z tensor
argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......@@ -277,9 +289,19 @@ int run(int argc, char* argv[])
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
HIP_CHECK_ERROR(hipMemsetAsync(
c_device_buf.GetDeviceBuffer(), 0, c_device_buf.GetBufferSize(), stream));
HIP_CHECK_ERROR(hipMemsetAsync(
lse_device_buf.GetDeviceBuffer(), 0, lse_device_buf.GetBufferSize(), stream));
invoker.Run(argument, StreamConfig{stream, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
......
......@@ -303,6 +303,16 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
argument =
gemm.MakeArgument(p_a,
p_b0,
......@@ -326,8 +336,17 @@ int run(int argc, char* argv[])
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
invoker.Run(argument, StreamConfig{nullptr, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
for(std::size_t i = 0; i < group_count; i++)
{
const int& G0 = g0_g1_m_n_k_o[i][0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment