Unverified commit f04ec574, authored by Dan Yao, committed by GitHub
Browse files

Merge pull request #915 from ROCmSoftwarePlatform/mha-train-hipGraph

Use hipGraph capturing and replaying method for forward verification in examples/52_xxx
parents d5832ed4 437d35a2
......@@ -244,6 +244,18 @@ int run(int argc, char* argv[])
if(do_verification)
{
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
// run for storing z tensor
argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......@@ -277,9 +289,19 @@ int run(int argc, char* argv[])
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
HIP_CHECK_ERROR(hipMemsetAsync(
c_device_buf.GetDeviceBuffer(), 0, c_device_buf.GetBufferSize(), stream));
HIP_CHECK_ERROR(hipMemsetAsync(
lse_device_buf.GetDeviceBuffer(), 0, lse_device_buf.GetBufferSize(), stream));
invoker.Run(argument, StreamConfig{stream, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
......
......@@ -303,6 +303,18 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
// data objects for hipGraph verification
hipGraph_t graph;
hipGraphExec_t g_instance;
hipStream_t stream;
std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeRelaxed));
argument =
gemm.MakeArgument(p_a,
p_b0,
......@@ -326,7 +338,16 @@ int run(int argc, char* argv[])
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
invoker.Run(argument, StreamConfig{stream, false});
HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
HIP_CHECK_ERROR(hipGraphDebugDotPrint(graph, "grouped_fwd_debug.dot", 0x007f));
HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
HIP_CHECK_ERROR(hipStreamSynchronize(stream));
for(std::size_t i = 0; i < group_count; i++)
{
......
......@@ -15,3 +15,16 @@ inline void hip_check_error(hipError_t x)
throw std::runtime_error(ss.str());
}
}
// HIP_CHECK_ERROR(flag)
//
// Evaluates the expression `flag` exactly once and compares the result with
// hipSuccess. On any other value it throws std::runtime_error whose message
// contains the source file and line of the macro invocation followed by the
// text returned by hipGetErrorString() for that value.
//
// The argument is parenthesized so that it is always treated as one complete
// expression inside the expansion, and the whole body is wrapped in
// do { ... } while(0) so that the macro behaves like a single statement
// (the caller supplies the trailing semicolon).
#define HIP_CHECK_ERROR(flag)                                                      \
    do                                                                             \
    {                                                                              \
        const hipError_t hip_check_status_ = (flag);                               \
        if(hip_check_status_ != hipSuccess)                                        \
        {                                                                          \
            std::ostringstream ostr;                                               \
            ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
                 << hipGetErrorString(hip_check_status_);                          \
            throw std::runtime_error(ostr.str());                                  \
        }                                                                          \
    } while(0)
......@@ -5,6 +5,7 @@
#include <iostream>
#include <sstream>
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
......@@ -912,10 +913,34 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
HIP_CHECK_ERROR(hipStreamIsCapturing(stream_config.stream_id_, &status));
if(status == hipStreamCaptureStatusActive)
{
size_t copy_size = arg.group_kernel_args_.size() * sizeof(GroupKernelArg);
// TODO: this host buffer is allocated with new[] and never released; decide who frees it and when.
char* persistent_ptr = new char[copy_size];
(void)std::memcpy(persistent_ptr, arg.group_kernel_args_.data(), copy_size);
HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
persistent_ptr,
copy_size,
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
else
{
HIP_CHECK_ERROR(
hipMemcpyAsync(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment