"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "57c3b3644cbdc0bb508de0eed4a9546fb793d061"
Commit da8c0608 authored by rocking's avatar rocking
Browse files

[What] Suport non pointer for invoker and argument

[Why] Snyc coding style with gemm
parent 2fc2a189
......@@ -274,25 +274,25 @@ int main()
reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker_ptr = normalize.MakeInvokerPointer();
auto normalize_argument =
normalize.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
layerNorm_device_buf.GetDeviceBuffer(),
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument.get()))
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise_Xdl_CShuffle instance, exiting!");
......@@ -300,7 +300,7 @@ int main()
// run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
normalize_invoker_ptr->Run(normalize_argument.get(), StreamConfig{nullptr, time_kernel});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
bool pass = true;
{
......
......@@ -215,6 +215,8 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
}
};
bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
......@@ -260,6 +262,37 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
return true;
};
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{p_a,
p_b,
p_c,
p_d,
p_e,
p_f,
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
functor};
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_c,
......@@ -291,8 +324,9 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
functor);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment