Commit 0dd17574 authored by Jing Zhang

add instances

parent b2ba0a69
add_executable(client_grouped_gemm_fixed_nk grouped_gemm_fixed_nk.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
using F16 = ck::half_t;
using F32 = float;
@@ -20,20 +20,23 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddBias = ck::tensor_operation::element_wise::AddBias;
using ADataType = F16;
using BDataType = F16;
using DsDataType = ck::Tuple<>;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using CDEElementOp = AddBias;
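// The bias change in this example: the Ds tuple now holds one FP32, row-major D0 (bias) tensor
// and the CDE elementwise op switches from PassThrough to AddBias.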
struct SimpleDeviceMem
{
@@ -90,24 +93,22 @@ int main()
}
};
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
d0_dev_bufs.reserve(group_count);
e_dev_bufs.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_e;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<>>
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
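// The template argument of GroupedGemmKernelArgument is now 1, i.e. one D (bias) tensor per
// group; the one-element pointer/stride arrays filled in below match it.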
grouped_gemm_kernel_args_.reserve(group_count);
@@ -117,26 +118,27 @@ int main()
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
b_dev_bufs.emplace_back(sizeof(BDataType) *
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
d0_dev_bufs.emplace_back(sizeof(D0DataType) *
f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{}));
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 0, StrideBs[i], 0, {0}});
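// In the fixed-NK descriptor the A/E/D strides are placeholders (0); the actual per-group
// strides are supplied through the kernel arguments below.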
p_a.push_back(a_dev_bufs[i].GetDeviceBuffer());
p_b.push_back(b_dev_bufs[i].GetDeviceBuffer());
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
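// Updated per-group kernel argument: it additionally carries the D0 (bias) device pointer, and
// the D stride of 0 broadcasts the N-length bias row across all M rows of the output.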
grouped_gemm_kernel_args_.push_back(
{a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
std::array<const void*, 1>{d0_dev_bufs[i].GetDeviceBuffer()},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
std::array<ck::index_t, 1>{0},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
@@ -168,24 +170,23 @@ int main()
float best_tflops = 0;
float best_gb_per_sec = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
std::vector<const void*> p_a = {}, p_b = {};
std::vector<std::array<const void*, 1>> p_ds = {};
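// The fixed-NK path allows the A/B/D pointer lists given to MakeArgumentPointer to stay empty
// here; the per-group pointers travel in grouped_gemm_kernel_args_, which is copied into the
// workspace below.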
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
// op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
......
@@ -88,7 +88,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<std::array<const void*, 1>> p_Ds;
std::vector<void*> p_Cs;
gemm_descs.reserve(group_count);
@@ -201,15 +200,14 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
p_Ds.push_back(std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
0,
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
0,
{0}});
grouped_gemm_kernel_args_.push_back(
@@ -233,8 +231,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<std::array<const void*, 1>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
......
@@ -569,9 +569,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) || NumDTensor == 0))
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
0 == ck::type_convert<ck::index_t>(p_Ds.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Ds");
throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
}
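// p_Ds may now be empty even when NumDTensor > 0; the D pointers then come from the per-group
// kernel arguments and the host-side grid pointers fall back to nullptr below.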
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
@@ -602,7 +603,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
p_ds_grid[j] =
static_cast<const DDataType*>(p_Ds.size() == 0 ? nullptr : p_Ds[i][j]);
});
// tensor descriptors for problem definiton
@@ -616,6 +618,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
......
@@ -97,6 +97,7 @@ using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using AddBias = ck::tensor_operation::element_wise::AddBias;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
@@ -7,7 +7,7 @@
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -16,57 +16,59 @@ namespace tensor_operation {
namespace device {
namespace instance {
// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
PassThrough,
PassThrough,
AddBias>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F32_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
AddBias>>>& instances);
// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
#if 0
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
PassThrough,
PassThrough,
AddBias>>>& instances);
// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
PassThrough,
PassThrough,
AddBias>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
@@ -77,27 +79,27 @@ template <typename ALayout,
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Empty_Tuple,
Row_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
F32_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>>
AddBias>>
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Empty_Tuple,
Row_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
F32_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>;
AddBias>;
static auto GetInstances()
{
@@ -106,26 +108,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
//{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
//}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
// if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
//{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
//}
// if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
// is_same_v<ELayout, Row>)
//{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
//}
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
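// Usage sketch (hypothetical client code; assumes Row_Tuple = ck::Tuple<Row> and
// F32_Tuple = ck::Tuple<F32> as in the other instance headers):
//   using DeviceOp = DeviceGroupedGemmFixedNK<Row, Col, Row_Tuple, Row, F16, F16, F32_Tuple,
//                                             F16, PassThrough, PassThrough, AddBias>;
//   const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();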
......
@@ -7,5 +7,4 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::AddBias;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
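// Each tuple entry above is one tile configuration of DeviceGroupedGemm_Xdl_Fixed_NK with the
// FP32 D0 (bias) tensor and the AddBias epilogue; they are registered by the function below.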
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -30,7 +30,8 @@ using DsDataType = ck::Tuple<D0DataType>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using Add = ck::tensor_operation::element_wise::AddBias;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::AddBias;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -63,12 +64,12 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
DsLayout,
Row,
F16,
F32,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
Add>>>& instances)
{
add_device_operation_instances(
instances,
......
@@ -75,6 +75,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_bias_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
......