"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "1ae7778447e0489cbe086fe24ea764105ffa9eb8"
Commit aef8ea3c authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Use persistent host pointer when doing hipMemcpyAsync under hipGraph environment

parent 6355e068
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <cstring>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -687,12 +688,34 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle ...@@ -687,12 +688,34 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString( hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), HIP_CHECK_ERROR(hipStreamIsCapturing(stream_config.stream_id_, &status));
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice, if(status == hipStreamCaptureStatusActive)
stream_config.stream_id_)); {
size_t copy_size = arg.group_kernel_args_.size() * sizeof(GroupKernelArg);
// ToDO: when to release this memory buffer?
char* persistent_ptr = new char[copy_size];
(void)std::memcpy(persistent_ptr, arg.group_kernel_args_.data(), copy_size);
HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
persistent_ptr,
copy_size,
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
else
{
HIP_CHECK_ERROR(
hipMemcpyAsync(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
float ave_time = 0; float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment