Commit 27f1c8cb authored by Jing Zhang's avatar Jing Zhang
Browse files

use array

parent b78c8719
...@@ -82,7 +82,7 @@ int main(int argc, char* argv[]) ...@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b; std::vector<const void*> p_a, p_b;
std::vector<std::vector<const void*>> p_ds; std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_c; std::vector<void*> p_c;
gemm_descs.reserve(group_count); gemm_descs.reserve(group_count);
......
...@@ -200,7 +200,7 @@ int main(int argc, char* argv[]) ...@@ -200,7 +200,7 @@ int main(int argc, char* argv[])
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
std::vector<std::vector<const void*>> p_Ds = {}; std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM // do GEMM
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
......
...@@ -28,10 +28,12 @@ template <typename ALayout, ...@@ -28,10 +28,12 @@ template <typename ALayout,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator struct DeviceGroupedGemm : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a, MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<std::vector<const void*>>& p_ds, std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e, std::vector<void*>& p_e,
std::vector<GemmDesc>& gemm_desc, std::vector<GemmDesc>& gemm_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
......
...@@ -532,7 +532,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout, ...@@ -532,7 +532,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
{ {
Argument(std::vector<const void*>& p_As, Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds, std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es, std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -755,7 +755,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout, ...@@ -755,7 +755,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static auto MakeArgument(std::vector<const void*>& p_As, static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds, std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es, std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs, std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -769,9 +769,10 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout, ...@@ -769,9 +769,10 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_As, std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
std::vector<std::vector<const void*>>& p_Ds, std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es, std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
......
...@@ -175,7 +175,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -175,7 +175,7 @@ bool profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
auto p_ds = std::vector<std::vector<const void*>>{}; auto p_ds = std::vector<std::array<const void*, 0>>{};
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : op_ptrs) for(auto& gemm_ptr : op_ptrs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment