"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "b8d64385a915c3faba13d443f509bd3149dfc0f7"
Commit 48d356fc authored by Jing Zhang's avatar Jing Zhang
Browse files

add a instance

parent 5d61cd96
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -30,6 +30,7 @@ using Row = ck::tensor_layout::gemm::RowMajor; ...@@ -30,6 +30,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddBias = ck::tensor_operation::element_wise::AddBias;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
...@@ -48,28 +49,7 @@ using ELayout = Row; ...@@ -48,28 +49,7 @@ using ELayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
struct Add using CDEElementOp = AddBias;
{
template <typename E, typename C, typename D0>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0) const;
template <>
__host__ __device__ void
operator()<ck::half_t, float, float>(ck::half_t& e, const float& c, const float& d0) const
{
e = c + d0;
}
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t, float>(ck::half_t& e,
const ck::half_t& c,
const float& d0) const
{
e = c + d0;
}
};
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
......
...@@ -10,6 +10,72 @@ namespace ck { ...@@ -10,6 +10,72 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
struct Add
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp + x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 + x1;
};
};
struct ScaleAdd struct ScaleAdd
{ {
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {} __host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
......
...@@ -120,6 +120,27 @@ struct PassThrough ...@@ -120,6 +120,27 @@ struct PassThrough
} }
}; };
struct AddBias
{
template <typename E, typename C, typename D0>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0) const;
template <>
__host__ __device__ void
operator()<ck::half_t, float, float>(ck::half_t& e, const float& c, const float& d0) const
{
e = c + d0;
}
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t, float>(ck::half_t& e,
const ck::half_t& c,
const float& d0) const
{
e = c + d0;
}
};
struct UnaryConvert struct UnaryConvert
{ {
template <typename Y, typename X> template <typename Y, typename X>
......
...@@ -16,7 +16,7 @@ namespace tensor_operation { ...@@ -16,7 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances( // void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row, // std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row, // Row,
// Empty_Tuple, // Empty_Tuple,
...@@ -29,7 +29,7 @@ namespace instance { ...@@ -29,7 +29,7 @@ namespace instance {
// PassThrough, // PassThrough,
// PassThrough>>>& instances); // PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col, Col,
Empty_Tuple, Empty_Tuple,
...@@ -42,7 +42,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( ...@@ -42,7 +42,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_km_kn_mn_instances( // void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col, // std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Row, // Row,
// Empty_Tuple, // Empty_Tuple,
...@@ -55,7 +55,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( ...@@ -55,7 +55,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
// PassThrough, // PassThrough,
// PassThrough>>>& instances); // PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_km_nk_mn_instances( // void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col, // std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Col, // Col,
// Empty_Tuple, // Empty_Tuple,
...@@ -109,22 +109,22 @@ struct DeviceOperationInstanceFactory< ...@@ -109,22 +109,22 @@ struct DeviceOperationInstanceFactory<
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && // if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>) // is_same_v<ELayout, Row>)
//{ //{
// add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); // add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
//} //}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
} }
// if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && // if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>) // is_same_v<ELayout, Row>)
//{ //{
// add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_km_kn_mn_instances(op_ptrs); // add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
//} //}
// if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && // if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
// is_same_v<ELayout, Row>) // is_same_v<ELayout, Row>)
//{ //{
// add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_km_nk_mn_instances(op_ptrs); // add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
//} //}
} }
return op_ptrs; return op_ptrs;
......
...@@ -24,42 +24,47 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -24,42 +24,47 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>; using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using Add = ck::tensor_operation::element_wise::AddBias;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances =
// clang-format off std::tuple<
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on // clang-format on
>; >;
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col, Col,
Empty_Tuple, DsLayout,
Row, Row,
F16, F16,
F16, F32,
Empty_Tuple, DsDataType,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -67,7 +72,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( ...@@ -67,7 +72,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment