Commit 27f1c8cb authored by Jing Zhang's avatar Jing Zhang
Browse files

use array

parent b78c8719
......@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<std::vector<const void*>> p_ds;
std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_c;
gemm_descs.reserve(group_count);
......
......@@ -200,7 +200,7 @@ int main(int argc, char* argv[])
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::vector<const void*>> p_Ds = {};
std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
......
......@@ -28,10 +28,12 @@ template <typename ALayout,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<std::vector<const void*>>& p_ds,
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmDesc>& gemm_desc,
AElementwiseOperation a_element_op,
......
......@@ -532,7 +532,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
......@@ -755,7 +755,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op,
......@@ -769,14 +769,15 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
......
......@@ -175,7 +175,7 @@ bool profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
auto p_ds = std::vector<std::vector<const void*>>{};
auto p_ds = std::vector<std::array<const void*, 0>>{};
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment