Commit 5509e684 authored by fsx950223's avatar fsx950223
Browse files

fix arguments

parent 2ebc3248
...@@ -665,15 +665,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -665,15 +665,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<const DataType*>& p_As, Argument(const std::vector<const void*>& p_As,
const std::vector<const DataType*>& p_Bs, const std::vector<const void*>& p_Bs,
const std::vector<const DataType*>& p_B1s, const std::vector<const void*>& p_B1s,
const std::vector<const DataType*>& p_Cs, // for dS const std::vector<const void*>& p_Cs, // for dS
const std::vector<const LSEDataType*>& p_LSEs, const std::vector<const void*>& p_LSEs,
const std::vector<const DataType*>& p_Ygrads, const std::vector<const void*>& p_Ygrads,
std::vector<DataType*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<DataType*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<DataType*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::array<void*, NumAcc0Bias>& p_acc0_biases,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::array<void*, NumAcc1Bias>& p_acc1_biases,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
...@@ -1042,15 +1042,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1042,15 +1042,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg); return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
} }
static auto MakeArgument(const std::vector<const DataType*>& p_As, static auto MakeArgument(const std::vector<const void*>& p_As,
const std::vector<const DataType*>& p_Bs, const std::vector<const void*>& p_Bs,
const std::vector<const DataType*>& p_B1s, const std::vector<const void*>& p_B1s,
const std::vector<const DataType*>& p_Cs, // for dS const std::vector<const void*>& p_Cs, // for dS
const std::vector<const LSEDataType*>& p_LSEs, const std::vector<const void*>& p_LSEs,
const std::vector<const DataType*>& p_Ygrads, const std::vector<const void*>& p_Ygrads,
std::vector<DataType*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<DataType*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<DataType*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::array<void*, NumAcc0Bias>& p_acc0_biases,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::array<void*, NumAcc1Bias>& p_acc1_biases,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
...@@ -1084,15 +1084,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1084,15 +1084,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
// polymorphic // polymorphic
// FIXME: constness // FIXME: constness
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<const DataType*>& p_As, MakeArgumentPointer(const std::vector<const void*>& p_As,
const std::vector<const DataType*>& p_Bs, const std::vector<const void*>& p_Bs,
const std::vector<const DataType*>& p_B1s, const std::vector<const void*>& p_B1s,
const std::vector<const DataType*>& p_Cs, // for dS const std::vector<const void*>& p_Cs, // for dS
const std::vector<const LSEDataType*>& p_LSEs, const std::vector<const void*>& p_LSEs,
const std::vector<const DataType*>& p_Ygrads, const std::vector<const void*>& p_Ygrads,
std::vector<DataType*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<DataType*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<DataType*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::array<void*, NumAcc0Bias>& p_acc0_biases,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::array<void*, NumAcc1Bias>& p_acc1_biases,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment