Unverified Commit 487da052 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp


Co-authored-by: default avatarBartłomiej Kocot <barkocot@amd.com>
parent 9715904a
...@@ -62,6 +62,19 @@ struct DeviceGroupedGemmMultiABD : public BaseOperator ...@@ -62,6 +62,19 @@ struct DeviceGroupedGemmMultiABD : public BaseOperator
static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor"); static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
/**
* \brief Make argument pointer for grouped gemm multi abd.
*
* \param p_as A pointers to the A.
* \param p_bs A pointers to the B.
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the E.
* \param gemm_desc Gemm descriptors for each group.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as, MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
std::vector<std::array<const void*, NumBTensor>>& p_bs, std::vector<std::array<const void*, NumBTensor>>& p_bs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment