Commit 646e81a6 authored by Jing Zhang's avatar Jing Zhang
Browse files

add f8_fp16 support

parent 473b76b9
......@@ -20,7 +20,8 @@
namespace ck {
template <typename GridwiseGemm,
typename ABDataType,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
......@@ -36,8 +37,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
......@@ -242,7 +243,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = ADataType;
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
......@@ -442,6 +443,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
BDataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
......
......@@ -229,10 +229,17 @@ struct MultiplyAdd
const float y = c * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
const float& c,
const float& d0,
const float& d1) const
{
const float y = c * d0 + d1;
e = y;
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
......
......@@ -32,6 +32,8 @@ using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -95,6 +97,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
......
......@@ -45,7 +45,33 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
PassThrough,
MultiplyAdd>>>&);
// GEMM + Add + Multiply
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
// GEMM + Multiply + Add
template <typename ALayout,
typename BLayout,
typename D0Layout,
......@@ -105,6 +131,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
......
add_instance_library(device_gemm_multiply_add_instance
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
)
......@@ -30,34 +30,11 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
// no padding
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// M/N/K padding
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using F32_Tuple = ck::Tuple<F32, F32>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row, Row>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
// M/N/K padding
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
F16,
F8,
F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using F32_Tuple = ck::Tuple<F32, F32>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row, Row>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
// M/N/K padding
// N % 8 == 0 && K % 1 == 0
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>
// clang-format on
>;
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Tuple,
Row,
F16,
F8,
F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -33,7 +33,7 @@ template <typename ADataType,
typename D0Layout,
typename D1Layout,
typename ELayout>
bool profile_gemm_add_multiply_impl(int do_verification,
bool profile_gemm_multiply_add_impl(int do_verification,
int init_method,
bool /*do_log*/,
bool time_kernel,
......@@ -90,11 +90,11 @@ bool profile_gemm_add_multiply_impl(int do_verification,
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddMultiply;
using CDEElementOp = MultiplyAdd;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
......
......@@ -22,7 +22,8 @@ int profile_gemm_multiply_add(int argc, char* argv[])
enum struct MatrixDataType
{
F16_F16_F16_F16_F16, // 0
F16_F16_F16_F16_F16, // 0
F16_F8_F32_F32_F16, // 1
};
if(argc != 16)
......@@ -58,6 +59,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
const int StrideD1 = std::stoi(argv[14]);
const int StrideE = std::stoi(argv[15]);
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
......@@ -94,7 +96,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_multiply_impl<ADataType,
bool pass = ck::profiler::profile_gemm_multiply_add_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
......@@ -130,6 +132,16 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
layout == MatrixLayout::MK_KN_MN_MN_MN)
{
return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
layout == MatrixLayout::MK_NK_MN_MN_MN)
{
return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment