Commit 51a549c9 authored by Jing Zhang's avatar Jing Zhang
Browse files

moved hipMemAlloc outside of deviceOp

parent d8f1458f
......@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0);
}
int group_count = 4;
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
......@@ -189,11 +189,20 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(gemm_shapes.size()));
// do GEMM
auto argument = gemm.MakeArgument(p_a,
p_b,
p_c,
gemm_shapes,
gemm_desc_workspace.GetDeviceBuffer(),
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -51,6 +51,7 @@ struct DeviceGroupedGemm : public BaseOperator
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -350,6 +350,7 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -363,6 +364,8 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
gemm_descs_args_workspace_ = gemm_descs_args_workspace;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
......@@ -437,6 +440,8 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_;
};
......@@ -485,12 +490,13 @@ struct DeviceGroupedGemmXdl
}
}
void* gemm_descs_const_;
hipGetErrorString(hipMalloc(
&gemm_descs_const_, arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg)));
// void* gemm_descs_args_workspace;
// hipGetErrorString(hipMalloc(
// &gemm_descs_args_workspace, arg.gemm_desc_kernel_arg_.size() *
// sizeof(GemmDescKernelArg)));
hipGetErrorString(
hipMemcpy(gemm_descs_const_,
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
......@@ -515,7 +521,7 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(gemm_descs_const_),
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
......@@ -539,7 +545,7 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(gemm_descs_const_),
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
......@@ -581,11 +587,21 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape> gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
return Argument{p_a,
p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -595,13 +611,22 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
return std::make_unique<Argument>(p_a,
p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
......@@ -632,6 +657,11 @@ struct DeviceGroupedGemmXdl
return str.str();
}
static size_t GetWorkSpaceSize(const index_t group_count)
{
return group_count * sizeof(GemmDescKernelArg);
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment