"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "3737bb039aafe4b59510bbc180f6a3d930b417ee"
Commit e08e22b6 authored by Jing Zhang

use instanceFactory

parent 1ccac727
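In effect, this commit moves grouped-GEMM instance selection behind DeviceOperationInstanceFactory: the DeviceGroupedGemm interface gains layout and data-type template parameters, and call sites query the factory instead of forward-declaring the per-layout add_device_..._instances functions themselves. A minimal consumer sketch (the aliases below are defined locally for self-containment and assume the types used by the new grouped_gemm.hpp header):

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

int main()
{
    // The factory is specialized on the full interface type, so layouts and
    // data types select the matching instance list at compile time.
    using DeviceOp = ck::tensor_operation::device::
        DeviceGroupedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    // Non-empty only for the f16 layout combinations registered in the header.
    return op_ptrs.empty() ? 1 : 0;
}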
...@@ -14,7 +14,13 @@ struct GemmDesc
     ck::index_t stride_A_, stride_B_, stride_C_;
 };
 
-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
...@@ -31,11 +37,24 @@ struct DeviceGroupedGemm : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
-using DeviceGroupedGemmPtr = std::unique_ptr<
-    DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+using DeviceGroupedGemmPtr = std::unique_ptr<DeviceGroupedGemm<ALayout,
+                                                               BLayout,
+                                                               CLayout,
+                                                               ADataType,
+                                                               BDataType,
+                                                               CDataType,
+                                                               AElementwiseOperation,
+                                                               BElementwiseOperation,
+                                                               CElementwiseOperation>>;
 
 } // namespace device
 } // namespace tensor_operation
......
...@@ -130,8 +130,15 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGroupedGemmXdl
-    : public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
+                                                       BLayout,
+                                                       CLayout,
+                                                       ADataType,
+                                                       BDataType,
+                                                       EDataType,
+                                                       AElementwiseOperation,
+                                                       BElementwiseOperation,
+                                                       CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
    ALayout,
    BLayout,
    CLayout,
    ADataType,
    BDataType,
    CDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedGemm<ALayout,
                                       BLayout,
                                       CLayout,
                                       ADataType,
                                       BDataType,
                                       CDataType,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
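Each returned element is a std::unique_ptr<DeviceOp>, so callers drive every instance through the virtual interface (MakeArgumentPointer, MakeInvokerPointer), as the profiler below does. A small hypothetical helper for listing the returned instances by name, assuming the GetTypeString() virtual inherited from BaseOperator:

#include <iostream>

// Print the name of every instance the factory returned (illustrative only;
// works for any container of smart pointers to BaseOperator-derived types).
template <typename DeviceOpPtrs>
void print_instance_names(const DeviceOpPtrs& op_ptrs)
{
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << std::endl;
}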
...@@ -54,7 +54,9 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
     >;
 
 void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceGroupedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
......
...@@ -53,7 +53,9 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
     >;
 
 void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceGroupedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
......
...@@ -54,7 +54,9 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
     >;
 
 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceGroupedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
......
...@@ -51,7 +51,9 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
     >;
 
 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceGroupedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
......
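The four per-layout entry points above stay directly callable; the factory merely dispatches among them. A minimal sketch of direct population, assuming the same aliases as before (F16 = ck::half_t, Row = RowMajor, PassThrough = element_wise::PassThrough):

#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

int main()
{
    // The vector element type must match the layout/type-parameterized interface.
    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGroupedGemm<
        Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>
        instances;

    // Populate with the mk_kn_mn (all row-major) instance list directly.
    ck::tensor_operation::device::instance::
        add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(instances);

    return instances.empty() ? 1 : 0;
}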
...@@ -10,6 +10,8 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
+
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/conv_util.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
...@@ -17,30 +19,6 @@
 #include "ck/library/host_tensor/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough>;
-
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<DeviceGroupedGemmNoOpPtr>&);
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<DeviceGroupedGemmNoOpPtr>&);
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-    std::vector<DeviceGroupedGemmNoOpPtr>&);
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-    std::vector<DeviceGroupedGemmNoOpPtr>&);
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
-
 namespace ck {
 namespace profiler {
...@@ -51,7 +29,7 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout>
-void profile_grouped_gemm_impl(int do_verification,
+bool profile_grouped_gemm_impl(int do_verification,
                                int init_method,
                                bool do_log,
                                bool time_kernel,
...@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification,
                                const std::vector<int>& StrideBs,
                                const std::vector<int>& StrideCs)
 {
+    bool pass = true;
+
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
...@@ -170,43 +151,20 @@ void profile_grouped_gemm_impl(int do_verification,
         p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
     }
 
-    // add device GEMM instances
-    std::vector<ck::tensor_operation::device::instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
-
-    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
-                 is_same<CDataType, half_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::instance::
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::instance::
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::instance::
-                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::instance::
                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-        }
-    }
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
+                                                                     BLayout,
+                                                                     CLayout,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     CDataType,
+                                                                     AElementOp,
+                                                                     BElementOp,
+                                                                     CElementOp>;
 
-    if(gemm_ptrs.size() <= 0)
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    if(op_ptrs.size() <= 0)
     {
         throw std::runtime_error("wrong! no device GEMM instance found");
     }
...@@ -217,7 +175,7 @@ void profile_grouped_gemm_impl(int do_verification,
     float best_gb_per_sec = 0;
 
     // profile device GEMM instances
-    for(auto& gemm_ptr : gemm_ptrs)
+    for(auto& gemm_ptr : op_ptrs)
     {
         auto argument_ptr =
             gemm_ptr->MakeArgumentPointer(p_a,
...@@ -294,7 +252,8 @@ void profile_grouped_gemm_impl(int do_verification,
                 c_element_op);
 
             ref_invoker.Run(ref_argument);
 
-            ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData);
+            pass = pass && ck::utils::check_err(c_m_n_device_results[i].mData,
+                                                c_m_n_host_result.mData);
 
             if(do_log)
             {
...@@ -319,6 +278,8 @@ void profile_grouped_gemm_impl(int do_verification,
     std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
               << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
+
+    return pass;
 }
 
 } // namespace profiler
 } // namespace ck
......
...@@ -2,39 +2,8 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr<
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough>;
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<DeviceGroupedGemmPtr_>&);
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
+
+#include "profiler/include/profile_grouped_gemm_impl.hpp"
 
 namespace {
...@@ -43,11 +12,11 @@ using BDataType = ck::half_t;
 using CDataType = ck::half_t;
 using AccDataType = float;
 
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
+template <typename ALayout, typename BLayout, typename CLayout>
+bool TestGroupedGemm()
 {
     int group_count = rand() % 10 + 1;
...@@ -56,156 +25,36 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
     std::vector<const void*> p_a, p_b;
     std::vector<void*> p_c;
 
-    gemm_descs.reserve(group_count);
+    std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
 
     for(int i = 0; i < group_count; i++)
     {
-        int M = 256 + 256 * (rand() % 10);
-        int N = 256 + 256 * (rand() % 10);
-        int K = 128 + 128 * (rand() % 10);
+        Ms.push_back(256 + 256 * (rand() % 10));
+        Ns.push_back(256 + 256 * (rand() % 10));
+        Ks.push_back(128 + 128 * (rand() % 10));
 
-        int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
-        int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
-        int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M;
-
-        gemm_descs.push_back({M, N, K, AStride, BStride, CStride});
-    }
-
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
-            }
-            else
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
-            }
-        };
-
-    std::vector<Tensor<ADataType>> a_tensors;
-    ;
-    std::vector<Tensor<BDataType>> b_tensors;
-    std::vector<Tensor<CDataType>> c_host_tensors;
-    std::vector<Tensor<CDataType>> c_device_tensors;
-
-    a_tensors.reserve(group_count);
-    b_tensors.reserve(group_count);
-    c_host_tensors.reserve(group_count);
-    c_device_tensors.reserve(group_count);
-
-    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
-
-    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
-
-    a_tensors_device.reserve(group_count);
-    b_tensors_device.reserve(group_count);
-    c_tensors_device.reserve(group_count);
-
-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
-    {
-        a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
-        b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
-        c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
-        c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
-
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
+        StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
+        StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
     }
 
-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
-    {
-        a_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
-        b_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize()));
-        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
-
-        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
-        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
-
-        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
-        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
-        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
-    }
-
-    auto a_element_op = PassThrough{};
-    auto b_element_op = PassThrough{};
-    auto c_element_op = PassThrough{};
-
-    // do GEMM
-    auto invoker_ptr  = groupedGemmPtr->MakeInvokerPointer();
-    auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
-        p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
-
-    DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
-
-    groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
-
-    invoker_ptr->Run(argument_ptr.get());
-
-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
-    {
-        c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
-
-        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                                BDataType,
-                                                                                CDataType,
-                                                                                AccDataType,
-                                                                                PassThrough,
-                                                                                PassThrough,
-                                                                                PassThrough>;
-
-        auto ref_gemm    = ReferenceGemmInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
-
-        auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
-                                                  b_tensors[i],
-                                                  c_host_tensors[i],
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  c_element_op);
-
-        if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
-        {
-            return false;
-        }
-
-        ref_invoker.Run(ref_argument);
-
-        bool res = ck::utils::check_err(c_host_tensors[i].mData, c_device_tensors[i].mData);
-
-        std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl;
-
-        if(!res)
-            return false;
-    }
-
-    return true;
+    return ck::profiler::profile_grouped_gemm_impl<ADataType,
+                                                   BDataType,
+                                                   CDataType,
+                                                   AccDataType,
+                                                   ALayout,
+                                                   BLayout,
+                                                   CLayout>(
+        true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 } // anonymous namespace
 
 int main()
 {
-    std::vector<DeviceGroupedGemmPtr_> groupedGemmPtrs;
-    ck::tensor_operation::device::instance::
-        add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs);
-
     bool res = true;
-    for(auto& gemmPtr : groupedGemmPtrs)
-    {
-        res &= TestGroupedGemm(gemmPtr);
-    }
+
+    res = res && TestGroupedGemm<Row, Row, Row>();
 
     std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
......
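Only the <Row, Row, Row> case is visible before the elision above. Exercising every layout combination the factory registers would look like the following sketch (hypothetical; the elided lines of the actual commit may already contain it):

    res = res && TestGroupedGemm<Row, Row, Row>();
    res = res && TestGroupedGemm<Row, Col, Row>();
    res = res && TestGroupedGemm<Col, Row, Row>();
    res = res && TestGroupedGemm<Col, Col, Row>();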