"tests/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "1e2a2c681de94d45ce997fbd8c4a6b162c091bd8"
Commit 51a549c9 authored by Jing Zhang
Browse files

Moved hipMalloc outside of deviceOp

parent d8f1458f
...@@ -78,7 +78,7 @@ int main(int argc, char* argv[]) ...@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 4; int group_count = rand() % 16 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
...@@ -189,11 +189,20 @@ int main(int argc, char* argv[]) ...@@ -189,11 +189,20 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(gemm_shapes.size()));
// do GEMM
auto argument = gemm.MakeArgument(p_a,
p_b,
p_c,
gemm_shapes,
gemm_desc_workspace.GetDeviceBuffer(),
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
......
...@@ -51,6 +51,7 @@ struct DeviceGroupedGemm : public BaseOperator ...@@ -51,6 +51,7 @@ struct DeviceGroupedGemm : public BaseOperator
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes, std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
......
...@@ -350,6 +350,7 @@ struct DeviceGroupedGemmXdl ...@@ -350,6 +350,7 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes, std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -363,6 +364,8 @@ struct DeviceGroupedGemmXdl ...@@ -363,6 +364,8 @@ struct DeviceGroupedGemmXdl
{ {
grid_size_ = 0; grid_size_ = 0;
gemm_descs_args_workspace_ = gemm_descs_args_workspace;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) && if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
...@@ -437,6 +440,8 @@ struct DeviceGroupedGemmXdl ...@@ -437,6 +440,8 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_; index_t grid_size_;
}; };
...@@ -485,12 +490,13 @@ struct DeviceGroupedGemmXdl ...@@ -485,12 +490,13 @@ struct DeviceGroupedGemmXdl
} }
} }
void* gemm_descs_const_; // void* gemm_descs_args_workspace;
hipGetErrorString(hipMalloc( // hipGetErrorString(hipMalloc(
&gemm_descs_const_, arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg))); // &gemm_descs_args_workspace, arg.gemm_desc_kernel_arg_.size() *
// sizeof(GemmDescKernelArg)));
hipGetErrorString( hipGetErrorString(
hipMemcpy(gemm_descs_const_, hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
...@@ -515,7 +521,7 @@ struct DeviceGroupedGemmXdl ...@@ -515,7 +521,7 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(gemm_descs_const_), cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -539,7 +545,7 @@ struct DeviceGroupedGemmXdl ...@@ -539,7 +545,7 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(gemm_descs_const_), cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -581,11 +587,21 @@ struct DeviceGroupedGemmXdl ...@@ -581,11 +587,21 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape> gemm_shapes, std::vector<GemmShape> gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; return Argument{p_a,
p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -595,13 +611,22 @@ struct DeviceGroupedGemmXdl ...@@ -595,13 +611,22 @@ struct DeviceGroupedGemmXdl
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes, std::vector<GemmShape>& gemm_shapes,
void* gemm_descs_args_workspace,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>( return std::make_unique<Argument>(p_a,
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); p_b,
p_c,
gemm_shapes,
gemm_descs_args_workspace,
1,
1,
a_element_op,
b_element_op,
c_element_op);
} }
// polymorphic // polymorphic
...@@ -632,6 +657,11 @@ struct DeviceGroupedGemmXdl ...@@ -632,6 +657,11 @@ struct DeviceGroupedGemmXdl
return str.str(); return str.str();
} }
/// Size in bytes of a buffer holding one GemmDescKernelArg per group.
///
/// @param group_count Number of GEMM groups.
/// @return group_count * sizeof(GemmDescKernelArg).
static size_t GetWorkSpaceSize(const index_t group_count)
{
    const size_t bytes_per_group = sizeof(GemmDescKernelArg);
    return group_count * bytes_per_group;
}
}; };
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment