".github/vscode:/vscode.git/clone" did not exist on "ad55ce61001c132e70a07231c2491387af3805fa"
Commit 0dd17574 authored by Jing Zhang

add instances

parent b2ba0a69
add_executable(client_grouped_gemm_fixed_nk grouped_gemm_fixed_nk.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
-#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
 
 using F16 = ck::half_t;
 using F32 = float;
@@ -20,20 +20,23 @@
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddBias     = ck::tensor_operation::element_wise::AddBias;
 
 using ADataType  = F16;
 using BDataType  = F16;
-using DsDataType = ck::Tuple<>;
+using D0DataType = F32;
+using DsDataType = ck::Tuple<D0DataType>;
 using EDataType  = F16;
 
 using ALayout  = Row;
 using BLayout  = Col;
-using DsLayout = ck::Tuple<>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<D0Layout>;
 using ELayout  = Row;
 
 using AElementOp   = PassThrough;
 using BElementOp   = PassThrough;
-using CDEElementOp = PassThrough;
+using CDEElementOp = AddBias;
 
 struct SimpleDeviceMem
 {
@@ -90,24 +93,22 @@ int main()
     }
 };
 
-std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;
+std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs;
 
 a_dev_bufs.reserve(group_count);
 b_dev_bufs.reserve(group_count);
+d0_dev_bufs.reserve(group_count);
 e_dev_bufs.reserve(group_count);
 
-std::vector<const void*> p_a, p_b;
 std::vector<void*> p_e;
 
-p_a.reserve(group_count);
-p_b.reserve(group_count);
 p_e.reserve(group_count);
 
 std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
 gemm_descs.reserve(group_count);
 
-std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<>>
+std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
     grouped_gemm_kernel_args_;
 grouped_gemm_kernel_args_.reserve(group_count);
@@ -117,26 +118,27 @@ int main()
         f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
     b_dev_bufs.emplace_back(sizeof(BDataType) *
                             f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
+    d0_dev_bufs.emplace_back(sizeof(D0DataType) *
+                             f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{}));
     e_dev_bufs.emplace_back(sizeof(EDataType) *
                             f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
 
-    gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});
+    gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 0, StrideBs[i], 0, {0}});
 
-    p_a.push_back(a_dev_bufs[i].GetDeviceBuffer());
-    p_b.push_back(b_dev_bufs[i].GetDeviceBuffer());
     p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
 
-    grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
-                                         b_dev_bufs[i].GetDeviceBuffer(),
-                                         {},
-                                         e_dev_bufs[i].GetDeviceBuffer(),
-                                         Ms[i],
-                                         Ns[i],
-                                         Ks[i],
-                                         StrideAs[i],
-                                         StrideBs[i],
-                                         {},
-                                         StrideEs[i]});
+    grouped_gemm_kernel_args_.push_back(
+        {a_dev_bufs[i].GetDeviceBuffer(),
+         b_dev_bufs[i].GetDeviceBuffer(),
+         std::array<const void*, 1>{d0_dev_bufs[i].GetDeviceBuffer()},
+         e_dev_bufs[i].GetDeviceBuffer(),
+         Ms[i],
+         Ns[i],
+         Ks[i],
+         StrideAs[i],
+         StrideBs[i],
+         std::array<ck::index_t, 1>{0},
+         StrideEs[i]});
 }
 
 using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
@@ -168,24 +170,23 @@ int main()
 float best_tflops     = 0;
 float best_gb_per_sec = 0;
 
-auto p_ds = std::vector<std::array<const void*, 0>>{};
-
 // profile device operation instances
 std::cout << "Run all instances and do timing" << std::endl;
 
+std::vector<const void*> p_a = {}, p_b = {};
+std::vector<std::array<const void*, 1>> p_ds = {};
+
 for(int i = 0; i < op_ptrs.size(); ++i)
 {
     auto& op_ptr = op_ptrs[i];
 
     auto argument_ptr = op_ptr->MakeArgumentPointer(
         p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
 
     auto invoker_ptr = op_ptr->MakeInvokerPointer();
 
     SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
-    // op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
 
     std::string op_name = op_ptr->GetTypeString();
 
     hipGetErrorString(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
......
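Aside, not part of the diff: the example above switches from GroupedGemmKernelArgument<> to GroupedGemmKernelArgument<1> so that each per-group argument carries one D (bias) pointer and one D stride. A hypothetical sketch of that layout, with the field order inferred purely from the brace-initializer above (this is illustrative, not the actual CK definition):

#include <array>

// Hypothetical layout inferred from the brace-initializer in the example;
// the real ck::tensor_operation::device::GroupedGemmKernelArgument may differ.
template <int NumDTensor> // CK uses ck::index_t template parameters
struct GroupedGemmKernelArgumentSketch
{
    const void* p_a_grid;                          // A for this group
    const void* p_b_grid;                          // B for this group
    std::array<const void*, NumDTensor> p_ds_grid; // D tensors (here: the bias)
    void* p_e_grid;                                // output E
    int M, N, K;                                   // per-group problem size
    int StrideA, StrideB;
    std::array<int, NumDTensor> StrideDs;          // {0} above: one bias row reused for all M
    int StrideE;
};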
@@ -88,7 +88,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     // GEMM shape
     std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
-    std::vector<std::array<const void*, 1>> p_Ds;
     std::vector<void*> p_Cs;
 
     gemm_descs.reserve(group_count);
@@ -201,15 +200,14 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
     c_tensors_device[i]->SetZero();
 
-    p_Ds.push_back(std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()});
     p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
 
     gemm_descs.push_back({sum_of_m,
                           problem_size.Ns[i],
                           problem_size.Ks[i],
-                          problem_size.stride_As[i],
+                          0,
                           problem_size.stride_Bs[i],
-                          problem_size.stride_Cs[i],
+                          0,
                           {0}});
 
     grouped_gemm_kernel_args_.push_back(
@@ -233,8 +231,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
 
     std::vector<const void*> p_As = {};
     std::vector<const void*> p_Bs = {};
+    std::vector<std::array<const void*, 1>> p_Ds = {};
 
     // do GEMM
     auto argument = gemm.MakeArgument(
......
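Aside, not part of the diff: in both examples the D0 buffer is sized with stride 0 and the GemmDesc carries {0} for stride_Ds_, i.e. a single bias row of length N is broadcast over all M rows of the group. A minimal CPU reference of the intended math, assuming AddBias computes e = c + d0 (an assumption based on the operator's name; its definition is not shown in this commit):

#include <cstddef>
#include <vector>

// CPU reference for one group: E = A * B + D0, with a stride-0, row-major D0
// broadcasting one bias row over every output row. Mirrors the examples above;
// this is not CK code.
void reference_gemm_add_bias(const std::vector<float>& a,  // M x K, row-major
                             const std::vector<float>& b,  // K x N, column-major
                             const std::vector<float>& d0, // N entries, one bias row
                             std::vector<float>& e,        // M x N, row-major
                             std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[n * K + k]; // column-major B: (k, n) at n*K + k
            e[m * N + n] = acc + d0[n]; // StrideD0 == 0: same bias row for every m
        }
}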
@@ -569,9 +569,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
     }
 
-    if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) || NumDTensor == 0))
+    if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
+         0 == ck::type_convert<ck::index_t>(p_Ds.size())))
     {
-        throw std::runtime_error("wrong! group_count_ != p_Ds");
+        throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
     }
 
     if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
@@ -602,7 +603,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static_for<0, NumDTensor, 1>{}([&](auto j) {
         using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
 
-        p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
+        p_ds_grid[j] =
+            static_cast<const DDataType*>(p_Ds.size() == 0 ? nullptr : p_Ds[i][j]);
     });
 
     // tensor descriptors for problem definiton
@@ -616,6 +618,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static_for<0, NumDTensor, 1>{}([&](auto j) {
         using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
 
+        if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
+        {
+            throw std::runtime_error(
+                "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
+        }
+
         StrideDs[j] = gemm_descs[i].stride_Ds_[j];
 
         ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
             M, N, gemm_descs[i].stride_Ds_[j]);
......
@@ -97,6 +97,7 @@ using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
 using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
 using Gelu     = ck::tensor_operation::element_wise::Gelu;
 using Swish    = ck::tensor_operation::element_wise::Swish;
+using AddBias  = ck::tensor_operation::element_wise::AddBias;
 
 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
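Aside, not part of the diff: the hunk above only registers the AddBias alias next to the other element-wise ops. For orientation, a CK-style CDE element-wise functor receives the GEMM accumulator value and the extra D operand and writes E; a hedged sketch of the shape AddBias plausibly takes (its actual definition is not shown in this commit):

// Illustrative sketch only; the real AddBias definition lives elsewhere in CK.
struct AddBiasSketch
{
    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d) const
    {
        e = static_cast<E>(c + static_cast<C>(d)); // E = C + D0 (the bias)
    }
};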
@@ -7,7 +7,7 @@
 #include <memory>
 
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -16,57 +16,59 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
+                                                         Row,
+                                                         Row_Tuple,
+                                                         Row,
+                                                         F16,
+                                                         F16,
+                                                         F32_Tuple,
+                                                         F16,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         AddBias>>>& instances);
 
 void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                          Col,
-                                                         Empty_Tuple,
+                                                         Row_Tuple,
                                                          Row,
                                                          F16,
                                                          F16,
-                                                         Empty_Tuple,
+                                                         F32_Tuple,
                                                          F16,
                                                          PassThrough,
                                                          PassThrough,
-                                                         PassThrough>>>& instances);
+                                                         AddBias>>>& instances);
 
-// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
-
-// void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+#if 0
+void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
+                                                         Row,
+                                                         Row_Tuple,
+                                                         Row,
+                                                         F16,
+                                                         F16,
+                                                         F32_Tuple,
+                                                         F16,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         AddBias>>>& instances);
+
+void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
+                                                         Col,
+                                                         Row_Tuple,
+                                                         Row,
+                                                         F16,
+                                                         F16,
+                                                         F32_Tuple,
+                                                         F16,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         AddBias>>>& instances);
+#endif
 
 template <typename ALayout,
           typename BLayout,
@@ -77,27 +79,27 @@ template <typename ALayout,
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
                                                            BLayout,
-                                                           Empty_Tuple,
+                                                           Row_Tuple,
                                                            ELayout,
                                                            ADataType,
                                                            BDataType,
-                                                           Empty_Tuple,
+                                                           F32_Tuple,
                                                            EDataType,
                                                            PassThrough,
                                                            PassThrough,
-                                                           PassThrough>>
+                                                           AddBias>>
 {
     using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
                                               BLayout,
-                                              Empty_Tuple,
+                                              Row_Tuple,
                                               ELayout,
                                               ADataType,
                                               BDataType,
-                                              Empty_Tuple,
+                                              F32_Tuple,
                                               EDataType,
                                               PassThrough,
                                               PassThrough,
-                                              PassThrough>;
+                                              AddBias>;
 
     static auto GetInstances()
     {
@@ -106,26 +108,26 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                      is_same_v<EDataType, half_t>)
        {
-            // if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-            //              is_same_v<ELayout, Row>)
-            //{
-            //    add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
-            //}
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            }
 
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                          is_same_v<ELayout, Row>)
             {
                 add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
             }
 
-            // if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-            //              is_same_v<ELayout, Row>)
-            //{
-            //    add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
-            //}
+            if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                // add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            }
 
-            // if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-            //              is_same_v<ELayout, Row>)
-            //{
-            //    add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
-            //}
+            if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                         is_same_v<ELayout, Row>)
+            {
+                // add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            }
         }
 
         return op_ptrs;
    }
......
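Aside, not part of the diff: with the factory specialization above in place, a client obtains the bias instances the same way the example at the top of this commit does. A condensed sketch (assuming Row_Tuple and F32_Tuple are the usual one-element tuple aliases, ck::Tuple<Row> and ck::Tuple<float>):

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"

using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddBias     = ck::tensor_operation::element_wise::AddBias;

// Matches the factory specialization above (assumption: Row_Tuple == ck::Tuple<Row>,
// F32_Tuple == ck::Tuple<F32>).
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<
    Row, Col, ck::Tuple<Row>, Row,      // ALayout, BLayout, DsLayout, ELayout
    F16, F16, ck::Tuple<F32>, F16,      // ADataType, BDataType, DsDataType, EDataType
    PassThrough, PassThrough, AddBias>; // AElementOp, BElementOp, CDEElementOp

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
// op_ptrs now holds the mk_nk_mn bias instances registered above.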
@@ -7,5 +7,4 @@ add_instance_library(device_grouped_gemm_instance
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-    device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
 )
add_instance_library(device_grouped_gemm_bias_instance
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
    # device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instance.cpp
    # device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::AddBias;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Row,
                                                         DsLayout,
                                                         Row,
                                                         F16,
                                                         F16,
                                                         DsDataType,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
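Aside, not part of the diff: the registration helper at the end of the file expands the instance tuple into the type-erased vector the factory returns. Conceptually it behaves like the following sketch (using std::tuple for illustration; CK has its own Tuple, and the real helper lives in add_device_operation_instance.hpp):

#include <memory>
#include <tuple>
#include <vector>

// Illustrative only: append one default-constructed copy of every instance
// type in the tuple to the polymorphic op_ptrs vector.
template <typename BaseOp, typename... Instances>
void add_device_operation_instances_sketch(
    std::vector<std::unique_ptr<BaseOp>>& op_ptrs, const std::tuple<Instances...>&)
{
    (op_ptrs.push_back(std::make_unique<Instances>()), ...);
}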
@@ -30,7 +30,8 @@ using DsDataType = ck::Tuple<D0DataType>;
 using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Add = ck::tensor_operation::element_wise::AddBias;
 
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -63,12 +64,12 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
                                                          DsLayout,
                                                          Row,
                                                          F16,
-                                                         F32,
+                                                         F16,
                                                          DsDataType,
                                                          F16,
                                                          PassThrough,
                                                          PassThrough,
-                                                         PassThrough>>>& instances)
+                                                         Add>>>& instances)
 {
     add_device_operation_instances(
         instances,
@@ -75,6 +75,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
+target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_bias_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
......