"vscode:/vscode.git/clone" did not exist on "f2756253e6874a5af8d22ec37462d1ce75d99c94"
Commit 8c7d03ec authored by Jing Zhang
Browse files

add setElementwiseOp

parent 2a964f40
......@@ -243,8 +243,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std::vector<void*> p_Cs = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -265,6 +264,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
......
......@@ -50,9 +50,9 @@ struct DeviceGroupedGemmMultiABD : public BaseOperator
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmMultiABDDesc>& gemm_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{}) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -56,6 +56,10 @@ struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
// Pure-virtual hook for supplying the A, B and CDE element-wise operations
// for an argument object that is passed in type-erased form.
//
// p_arg          - argument object, handed over as a BaseArgument pointer.
// a_element_op   - element-wise operation for the A operand(s), by value.
// b_element_op   - element-wise operation for the B operand(s), by value.
// cde_element_op - element-wise operation for the C/D/E stage, by value.
//
// The method is const on the device-op object; any state change is expected
// to happen on *p_arg.
// NOTE(review): presumably intended to be called after the argument has been
// created and before the invoker runs -- confirm against the implementations.
virtual void SetElementwiseOps(BaseArgument* p_arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op) const = 0;
};
} // namespace device
......
......@@ -453,9 +453,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>&,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
{
grid_size_ = 0;
......@@ -754,9 +754,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc> gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
{
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
......@@ -771,9 +771,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
......@@ -814,6 +814,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
return str.str();
}
// Overwrites the A, B and CDE element-wise operators held by an existing
// Argument with the ones supplied by the caller.
//
// arg          - argument object whose a_element_op_, b_element_op_ and
//                c_element_op_ members are reassigned.
// a_element_op - new value for arg.a_element_op_.
// b_element_op - new value for arg.b_element_op_.
// c_element_op - new value for arg.c_element_op_.
static void SetElementwiseOps(Argument& arg,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation c_element_op)
{
    arg.a_element_op_ = a_element_op;
    arg.b_element_op_ = b_element_op;
    arg.c_element_op_ = c_element_op;
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
......@@ -825,6 +835,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
// Polymorphic entry point: downcasts the type-erased argument to this
// operation's concrete Argument type and forwards to the static
// SetElementwiseOps(Argument&, ...) overload.
//
// p_arg        - argument object; must actually be an Argument.
// a_element_op - element-wise operation forwarded for A.
// b_element_op - element-wise operation forwarded for B.
// c_element_op - element-wise operation forwarded for C/D/E.
//
// NOTE(review): the dynamic_cast result is dereferenced without a null
// check, so passing a BaseArgument of a different concrete type (or a null
// pointer) is undefined behaviour here.
void SetElementwiseOps(BaseArgument* p_arg,
                       AElementwiseOperation a_element_op,
                       BElementwiseOperation b_element_op,
                       CDEElementwiseOperation c_element_op) const override
{
    auto& typed_arg = *dynamic_cast<Argument*>(p_arg);
    SetElementwiseOps(typed_arg, a_element_op, b_element_op, c_element_op);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment