Commit 88578483 authored by Jing Zhang
Browse files

add SetWorkSpacePointer

parent 51a549c9
......@@ -192,17 +192,13 @@ int main(int argc, char* argv[])
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(gemm_shapes.size()));
// do GEMM
auto argument = gemm.MakeArgument(p_a,
p_b,
p_c,
gemm_shapes,
gemm_desc_workspace.GetDeviceBuffer(),
a_element_op,
b_element_op,
c_element_op);
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -42,6 +42,8 @@ struct BaseOperator
// Default workspace query: returns zero bytes and ignores the argument.
// Being virtual, it can be overridden by operators that need a workspace.
virtual size_t GetWorkSpaceSize(const BaseArgument*) const
{
    return 0;
}
// Default workspace binding: does nothing and ignores both the argument and the pointer.
// Being virtual, it can be overridden by operators that keep a workspace pointer.
virtual void SetWorkSpacePointer(BaseArgument*, void*) const
{
}
// Virtual destructor so that derived operators are destroyed correctly when deleted
// through a BaseOperator pointer; defaulted because there is nothing to release here.
virtual ~BaseOperator() = default;
};
......
......@@ -51,7 +51,6 @@ struct DeviceGroupedGemm : public BaseOperator
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -350,7 +350,6 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -364,7 +363,7 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
gemm_descs_args_workspace_ = gemm_descs_args_workspace;
gemm_descs_args_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
......@@ -490,11 +489,6 @@ struct DeviceGroupedGemmXdl
}
}
// void* gemm_descs_args_workspace;
// hipGetErrorString(hipMalloc(
// &gemm_descs_args_workspace, arg.gemm_desc_kernel_arg_.size() *
// sizeof(GemmDescKernelArg)));
hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
......@@ -587,21 +581,11 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape> gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op};
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
}
// Returns a default-constructed Invoker by value.
static auto MakeInvoker()
{
    return Invoker{};
}
......@@ -611,22 +595,13 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op);
return std::make_unique<Argument>(
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
}
// polymorphic
......@@ -658,9 +633,14 @@ struct DeviceGroupedGemmXdl
return str.str();
}
static size_t GetWorkSpaceSize(const index_t group_count)
// Reports the workspace size in bytes for the given argument:
// the argument's group_count_ multiplied by sizeof(GemmDescKernelArg).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg
// presumably must point to this operator's Argument type — confirm against callers.
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    const auto* arg = dynamic_cast<const Argument*>(p_arg);

    return arg->group_count_ * sizeof(GemmDescKernelArg);
}
// Stores workspace_ptr in the gemm_descs_args_workspace_ member of the given argument.
// The stale `return group_count * sizeof(GemmDescKernelArg);` statement (left over from
// the removed static size query) is dropped: it returned a value from a void function,
// referred to an undeclared name, and preceded the assignment so the pointer was never
// stored.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg
// presumably must point to this operator's Argument type — confirm against callers.
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
    dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
};
......
......@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto c_element_op = PassThrough{};
// do GEMM
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment