Commit 0c3cfcf8 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comment

parent 881ba357
......@@ -182,18 +182,19 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_kernel_args_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetWorkSpaceSize(argument_ptr.get()),
hipMemcpyHostToDevice));
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
......@@ -244,11 +245,12 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_kernel_args_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
......
......@@ -307,7 +307,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
return pass;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
ProblemSize problem_size;
......
......@@ -496,7 +496,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
......@@ -605,9 +604,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
if(group_id * grid_size_grp != grid_size_)
{
throw std::runtime_error("wrong! grid_size_grp is not identical!");
......@@ -655,9 +651,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
local_b2c_tile_map,
BlockStart,
BlockEnd});
local_b2c_tile_map});
}
group_id++;
......@@ -777,8 +771,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
const index_t grid_size_grp = arg.grid_size_ / arg.group_count_;
const void* kernel_args_dev = nullptr;
......@@ -798,6 +791,11 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}
if(arg.p_workspace_ == nullptr)
{
throw std::runtime_error("wrong! arg.p_workspace_ == nullptr");
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
grouped_gemm_kernel_args.data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment