"vscode:/vscode.git/clone" did not exist on "2d3235f8e99eaf26ecc3671ea3498fa9d9c097f5"
Commit f41a265a authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix allocation and setting workspace pointer.

parent defa2071
......@@ -228,10 +228,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
......@@ -247,7 +243,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatchSize(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
if(config.time_kernel)
{
......@@ -289,6 +288,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
return pass;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment