Commit 0c3cfcf8 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comment

parent 881ba357
...@@ -182,18 +182,19 @@ int main() ...@@ -182,18 +182,19 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); SimpleDeviceMem grouped_gemm_kernel_args_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(), hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(), grouped_gemm_kernel_args_.data(),
op_ptr->GetWorkSpaceSize(argument_ptr.get()), op_ptr->GetWorkSpaceSize(argument_ptr.get()),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
...@@ -244,11 +245,12 @@ int main() ...@@ -244,11 +245,12 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); SimpleDeviceMem grouped_gemm_kernel_args_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
} }
......
...@@ -307,7 +307,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -307,7 +307,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
return pass; return pass;
} }
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
ProblemSize problem_size; ProblemSize problem_size;
......
...@@ -496,7 +496,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -496,7 +496,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
}; };
// Argument // Argument
...@@ -605,9 +604,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -605,9 +604,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
if(group_id * grid_size_grp != grid_size_) if(group_id * grid_size_grp != grid_size_)
{ {
throw std::runtime_error("wrong! grid_size_grp is not identical!"); throw std::runtime_error("wrong! grid_size_grp is not identical!");
...@@ -655,9 +651,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -655,9 +651,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
local_b2c_tile_map, local_b2c_tile_map});
BlockStart,
BlockEnd});
} }
group_id++; group_id++;
...@@ -777,8 +771,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -777,8 +771,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEElementwiseOperation, CDEElementwiseOperation,
has_main_k_block_loop_>; has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ - const index_t grid_size_grp = arg.grid_size_ / arg.group_count_;
arg.gemm_desc_kernel_arg_[0].BlockStart_;
const void* kernel_args_dev = nullptr; const void* kernel_args_dev = nullptr;
...@@ -798,6 +791,11 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -798,6 +791,11 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
} }
} }
if(arg.p_workspace_ == nullptr)
{
throw std::runtime_error("wrong! arg.p_workspace_ == nullptr");
}
hipGetErrorString( hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyWithStream(arg.p_workspace_,
grouped_gemm_kernel_args.data(), grouped_gemm_kernel_args.data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment