Commit ecdc7e9d authored by Adam Osewski
Browse files

Refactoring

* Disable logging
* Extract the KBatch update out of the if statement.
parent 9eed0992
...@@ -165,7 +165,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -165,7 +165,7 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
...@@ -196,32 +196,33 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -196,32 +196,33 @@ bool profile_grouped_gemm_impl(int do_verification,
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(kbatch > 1)
{ {
if(kbatch > 1) using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
{ {
using DeviceOpSplitK = dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout, ->SetKBatchSize(argument_ptr.get(), kbatch);
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
{
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch);
}
} }
}
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel) if(time_kernel)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment