Unverified Commit c7a96ed5 authored by ltqin's avatar ltqin Committed by GitHub
Browse files

add p_workspace to baseargument (#275)

parent 6eb55499
......@@ -15,6 +15,8 @@ struct BaseArgument
BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {}
void* p_workspace_ = nullptr;
};
struct BaseInvoker
......@@ -42,7 +44,11 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final
{
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
virtual ~BaseOperator() {}
};
......
......@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
gemm_descs_args_workspace_ = nullptr;
p_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
......@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_;
};
......@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
}
hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
......@@ -507,13 +505,13 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
true>;
ave_time = launch_and_time_kernel(
stream_config,
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
......@@ -531,13 +529,13 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
false>;
ave_time = launch_and_time_kernel(
stream_config,
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
......@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment