"README_origin.md" did not exist on "a0cac22cab9fe74763a001ffdaffa52e84671e60"
Commit f41a265a authored by Adam Osewski
Browse files

Fix allocation and setting workspace pointer.

parent defa2071
...@@ -228,10 +228,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -228,10 +228,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(), grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument), gemm.GetDeviceKernelArgSize(&argument),
...@@ -247,7 +243,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -247,7 +243,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatchSize(argument, config.k_batch); gemm.SetKBatchSize(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false}); DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
if(config.time_kernel) if(config.time_kernel)
{ {
...@@ -289,6 +288,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -289,6 +288,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
} }
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
} }
return pass; return pass;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment