Commit a2b5c7ac authored by guangzlu's avatar guangzlu
Browse files

modified run_batched_multihead_attention_forward.inc

parent 016f85bf
...@@ -228,6 +228,7 @@ int run(int argc, char* argv[]) ...@@ -228,6 +228,7 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
...@@ -261,7 +262,8 @@ int run(int argc, char* argv[]) ...@@ -261,7 +262,8 @@ int run(int argc, char* argv[])
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number {seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread // of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment