"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "9e332f068a5affae6e5c4209a5e630f7f913f828"
Unverified Commit c7a96ed5 authored by ltqin's avatar ltqin Committed by GitHub
Browse files

add p_workspace to baseargument (#275)

parent 6eb55499
...@@ -15,6 +15,8 @@ struct BaseArgument ...@@ -15,6 +15,8 @@ struct BaseArgument
BaseArgument& operator=(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {} virtual ~BaseArgument() {}
void* p_workspace_ = nullptr;
}; };
struct BaseInvoker struct BaseInvoker
...@@ -42,7 +44,11 @@ struct BaseOperator ...@@ -42,7 +44,11 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {} virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final
{
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl ...@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
{ {
grid_size_ = 0; grid_size_ = 0;
gemm_descs_args_workspace_ = nullptr; p_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
...@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl ...@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_; index_t grid_size_;
}; };
...@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl ...@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
} }
hipGetErrorString( hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_, hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
...@@ -507,13 +505,13 @@ struct DeviceGroupedGemmXdl ...@@ -507,13 +505,13 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation, CElementwiseOperation,
true>; true>;
ave_time = launch_and_time_kernel( ave_time =
stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -531,13 +529,13 @@ struct DeviceGroupedGemmXdl ...@@ -531,13 +529,13 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation, CElementwiseOperation,
false>; false>;
ave_time = launch_and_time_kernel( ave_time =
stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl ...@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
{ {
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg); return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
} }
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
}; };
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment