Commit da8c0608 authored by rocking's avatar rocking
Browse files

[What] Suport non pointer for invoker and argument

[Why] Snyc coding style with gemm
parent 2fc2a189
...@@ -274,25 +274,25 @@ int main() ...@@ -274,25 +274,25 @@ int main()
reduceMeanSquare_device_buf.SetZero(); reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm // Prepare LayerNorm
auto normalize = DeviceNormalizeInstance{}; auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker_ptr = normalize.MakeInvokerPointer(); auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = auto normalize_argument = normalize.MakeArgument(
normalize.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
reduceMean_device_buf.GetDeviceBuffer(), static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
reduceMeanSquare_device_buf.GetDeviceBuffer(), static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
gamma_device_buf.GetDeviceBuffer(), static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
beta_device_buf.GetDeviceBuffer(), static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
layerNorm_device_buf.GetDeviceBuffer(), static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
{M, N}, {M, N},
{StrideC, 1}, {StrideC, 1},
{1, 0}, {1, 0},
{1, 0}, {1, 0},
{0, 1}, {0, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideC, 1},
NormalizeFunctor{}); NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument.get())) if(!normalize.IsSupportedArgument(normalize_argument))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise_Xdl_CShuffle instance, exiting!"); "Device5AryElementwise_Xdl_CShuffle instance, exiting!");
...@@ -300,7 +300,7 @@ int main() ...@@ -300,7 +300,7 @@ int main()
// run kernel // run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
normalize_invoker_ptr->Run(normalize_argument.get(), StreamConfig{nullptr, time_kernel}); normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
bool pass = true; bool pass = true;
{ {
......
...@@ -215,6 +215,8 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator ...@@ -215,6 +215,8 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
} }
}; };
bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
...@@ -260,6 +262,37 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator ...@@ -260,6 +262,37 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
return true; return true;
}; };
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{p_a,
p_b,
p_c,
p_d,
p_e,
p_f,
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
functor};
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_c, const void* p_c,
...@@ -291,8 +324,9 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator ...@@ -291,8 +324,9 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
functor); functor);
} }
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
}; }; // namespace device
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment