Commit 54d73870 authored by Jakub Piasecki

added bf16@int8 version

parent f76c0072
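
This commit extends the two-stage grouped GEMM (a split-K GEMM that accumulates into an fp32 workspace, followed by a fused elementwise pass that produces E) with bf16 A / int8 B / bf16 E instances, registers them in the instance factory, and adds a matching bf16@int8 data type to the profiler. A minimal lookup sketch (the include path is an assumption, not part of this commit) showing how the new instances become reachable through the usual factory route:

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" // assumed path

using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;

// Same template arguments as the add_..._bf16_i8_bf16_mk_kn_mn declaration below.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<Row,
                                                                 Row,
                                                                 Empty_Tuple,
                                                                 Row,
                                                                 ck::bhalf_t,
                                                                 int8_t,
                                                                 Empty_Tuple,
                                                                 ck::bhalf_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 PassThrough>;

// After this commit the returned vector also contains the bf16@int8 instances.
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
    DeviceOp>::GetInstances();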
......@@ -19,14 +19,13 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -103,7 +102,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1
static constexpr index_t K0PerBlock = KPerBlock / AK1;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using WorkspaceDataType = float;
// First stage GridwiseGEMM kernel.
......@@ -153,10 +152,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
PipelineVer,
ComputeDataType>;
// CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// indices 1, 3 -> MPerBlock, NPerBlock || divided by MPerBlock -> NPerThread
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
......@@ -192,7 +188,7 @@ template <typename ELay>
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
......@@ -219,33 +215,45 @@ template <typename ELay>
static constexpr auto MakeElementwiseInputSequence()
{
return generate_sequence_v2(
[&]([[maybe_unused]] auto i) constexpr {
return Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
},
Number<NumDTensor + 1>{});
}
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using DsGridPointer = decltype(MakeDsGridPointer());
using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
using CDDataTypes = decltype(concat_tuple(ck::Tuple<WorkspaceDataType*>{}, DsGridPointer{}));
using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence());
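// The elementwise stage consumes NumDTensor + 1 inputs: the fp32 split-K workspace
// (occupying the "C" slot) prepended to the Ds tuple, hence the NumDTensor + 1 copies
// of the CDE scalar-per-vector width generated above.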
static constexpr index_t ClusterLengthMPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GridwiseElementwise =
GridwiseElementwise<CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2ETileMapKSplit,
CDEElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<0, 1>,
ElementwiseInputSequence,
ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
true>;
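// Two-stage flow: stage one (GridwiseGemm) reduces each output tile's split-K partial
// products into a single WorkspaceDataType (fp32) tile in the device workspace; stage
// two (GridwiseElementwise) revisits the same M/N tiling through Block2ETileMapKSplit,
// applies CDEElementwiseOperation to the workspace tile plus the Ds operands, and
// stores the result as EDataType.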
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
......@@ -318,7 +326,6 @@ template <typename ELay>
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
// group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
......@@ -451,7 +458,7 @@ template <typename ELay>
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
group_grid_size_[i] = grid_size_grp;
karg.KPadded = k_padded;
karg.K0Padded = k0_padded;
karg.k_batch = K_BATCH;
......@@ -460,14 +467,15 @@ template <typename ELay>
gemm_kernel_args_[i].block_end_ = block_end;
#if DEBUG_LOG
index_t tiles = (block_end - block_start) / K_BATCH;
std::cout << "block_start: " << block_start << "\n"
<< "block_end: " << block_end << "\n"
<< "tiles: " << tiles << std::endl << std::endl;
<< "tiles: " << tiles << std::endl
<< std::endl;
std::cout << "KPadded: " << karg.KPadded << std::endl
<< "K0Padded: " << karg.K0Padded << std::endl
<< "KBatch: " << karg.k_batch << std::endl
<< "K0Padded: " << karg.K0Padded << std::endl
<< "KBatch: " << karg.k_batch << std::endl
<< "grid_size_: " << karg.KPadded << std::endl;
#endif
}
......@@ -476,16 +484,13 @@ template <typename ELay>
void UpdateEPointers()
{
// Set up each group's E pointer to its designated workspace memory.
WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
std::size_t offset = 0;
// TODO: per group e-ptr memory alignment (128B)?
for(auto& arg : gemm_kernel_args_)
{
arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
// TODO: what about padding and the in-memory layout? Is there a descriptor for it?
offset += tiles * MPerBlock * NPerBlock;
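// Illustrative numbers: with MPerBlock = NPerBlock = 128 and k_batch = 4, a group
// spanning blocks [0, 64) owns 64 / 4 = 16 output tiles, so the offset advances by
// 16 * 128 * 128 = 262144 elements (1 MiB of fp32 workspace).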
#if DEBUG_LOG
std::cout << "block_start: " << arg.block_start_ << "\n"
......@@ -499,7 +504,7 @@ template <typename ELay>
std::size_t GetWorkspaceSizeBytes() const
{
std::size_t size_bytes{0};
// TODO: per group e-ptr memory alignment (128B)?
for(const auto& arg : gemm_kernel_args_)
{
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
......@@ -534,10 +539,8 @@ template <typename ELay>
std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
std::vector<DsGridPointer> ds_grid_pointer_;
std::vector<void*> e_ptrs_;
};
// Invoker
......@@ -729,13 +732,19 @@ template <typename ELay>
BElementwiseOperation,
PassThrough>;
const auto elementwise_kernel = kernel_elementwise<GridwiseElementwise,
CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2ETileMapKSplit,
CDEElementwiseOperation>;
return LaunchKernel(gemm_kernel,
elementwise_kernel,
arg,
dev_gemm_args,
dev_gemm_workspace,
stream_config);
}
template <typename KernelFunction, typename KernelFunction2>
......@@ -767,20 +776,24 @@ template <typename ELay>
arg.b_element_op_,
PassThrough{});
// Elementwise kernels
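// One launch per group: the grid covers that group's output tiles
// (group_grid_size_[i]), and the tile map is rebuilt with the same B2E_M01 and
// K_BATCH as stage one, so the elementwise pass addresses exactly the tiles the
// GEMM stage wrote.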
for(int i = 0; i < arg.group_count_; ++i)
{
time += launch_and_time_kernel(
stream_config,
elementwise_kernel,
dim3(arg.group_grid_size_[i]),
dim3(BlockSize),
0,
concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
arg.elementwise_d_grid_descs_m_n_[i]),
make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
arg.ds_grid_pointer_[i]),
type_convert<EDataType*>(arg.e_ptrs_[i]),
Block2ETileMapKSplit{
arg.elementwise_c_grid_descs_m_n_[i], B2E_M01, arg.K_BATCH},
arg.cde_element_op_);
}
return time;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -159,6 +159,19 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_insta
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
......@@ -203,7 +216,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
......@@ -242,6 +256,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
......
......@@ -10,4 +10,5 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future
// a[m, k] * b[k, n] = e[m, n]
using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>
// clang-format on
>;
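// The bf16@int8 list mirrors the existing f16 tile configurations: block sizes
// 64-256, MN tiles from 32x64 up to 256x128, AK1 = BK1 = 8 throughout, 4-wide CDE
// vector stores, and GemmMNKPadding so arbitrary group shapes are accepted.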
void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -22,13 +22,10 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future
......@@ -39,36 +36,26 @@ using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, // err
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>
// clang-format on
>;
......@@ -85,10 +72,8 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_insta
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
......@@ -12,17 +12,12 @@
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
};
enum struct GemmDataType
{
F16_F16_F16, // 0
BF16_INT8_BF16 // 1
};
#define OP_NAME "grouped_gemm_two_stage"
......@@ -52,9 +47,8 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8@fp6; 5: f16@f8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< "arg2: data type (0: fp16; 1: bf16@int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
......@@ -81,16 +75,17 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
auto StrideAs = argToIntArray(argv[11]);
auto StrideBs = argToIntArray(argv[12]);
auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const int DefaultStrideA = Ks[0];
const int DefaultStrideB = Ns[0];
const int DefaultStrideC = Ns[0];
for(size_t i = 0; i < Ms.size(); ++i)
{
StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
......@@ -108,46 +103,48 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else
{
......
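
For reference, a hypothetical invocation of the new profiler path (binary name and argument order inferred from the code above; Ms/Ns/Ks and the strides are comma-separated per-group lists, and a stride of -1 selects the packed default):

./ckProfiler grouped_gemm_two_stage 1 0 1 1 0 1 256,512 1024,768 512,512 -1 -1 -1 4

Here arg2 = 1 selects bf16@int8, arg3 = 0 the row-major A[m, k] * B[k, n] layout, and the trailing 4 is the split-K batch (kbatch).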