Unverified Commit 832b69cb authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into grouped_conv_3d_layout_fix

parents 53130727 a35456a3
...@@ -76,7 +76,7 @@ int main(int argc, char* argv[]) ...@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
StrideA = std::stoi(argv[4]); StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]); StrideB = std::stoi(argv[5]);
StrideD0 = std::stoi(argv[6]); StrideD0 = std::stoi(argv[6]);
StrideE = std::stoi(argv[8]); StrideE = std::stoi(argv[7]);
} }
else else
{ {
......
...@@ -72,7 +72,7 @@ int main(int argc, char* argv[]) ...@@ -72,7 +72,7 @@ int main(int argc, char* argv[])
StrideA = std::stoi(argv[4]); StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]); StrideB = std::stoi(argv[5]);
StrideE = std::stoi(argv[8]); StrideE = std::stoi(argv[6]);
} }
else else
{ {
......
...@@ -121,7 +121,8 @@ using DeviceOpInstance = ...@@ -121,7 +121,8 @@ using DeviceOpInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec, // MaskingSpecialization
1>;
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*/
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename DsGridDesc_M0_M10_M11_N0_N10_N11,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
// TODO: Enable for gfx90a after complier fix
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::Run(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11,
e_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmMultipleD_Dl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
});
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
};
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
ADataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
EGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using DsGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
using EGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
K_(K),
Batch_(Batch),
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
e_grid_desc_m0_m10_m11_n0_n10_n11_{},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
});
e_grid_desc_m_n_ =
DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
ds_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_);
e_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
index_t K_;
// Batch
index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
// for calculating batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmMultipleD_Dl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0),
arg.e_grid_desc_m_n_.GetLength(I1)) *
arg.Batch_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
const auto kernel =
kernel_gemm_dl_multiple_d<GridwiseGemm,
ADataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
DeviceOp::EGridDesc_M0_M10_M11_N0_N10_N11,
ComputePtrOffsetOfStridedBatch,
DefaultBlock2CTileMap,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.e_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// TODO: Enable for gfx90a after complier fix
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.e_grid_desc_m_n_);
return pass;
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
Batch,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor>& StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
Batch,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmMultipleD_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -197,7 +197,8 @@ template <index_t NumDimG, ...@@ -197,7 +197,8 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> int D0sTransferSrcScalarPerVector = 4,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
...@@ -438,7 +439,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -438,7 +439,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
D0sTransferSrcScalarPerVector>;
// Argument // Argument
// FIXME: constness // FIXME: constness
...@@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>; using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0 pointer // D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]); p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
// for check
d0s_nl_ns_lengths_strides_[i].push_back(
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
d0s_nl_ns_lengths_strides_[i].push_back(
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
}); });
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
...@@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<index_t> b_nz_kz_strides_; std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_; std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
...@@ -772,6 +780,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -772,6 +780,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return false; return false;
} }
for(int i = 0; i < NumD0Tensor; i++)
{
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
{
std::cout << "first" << std::endl;
return false;
}
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
{
std::cout << "second" << std::endl;
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
......
...@@ -80,7 +80,8 @@ template <typename FloatAB, ...@@ -80,7 +80,8 @@ template <typename FloatAB,
LoopScheduler LoopSched, LoopScheduler LoopSched,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1> int D0sTransferSrcScalarPerVector = 4,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
...@@ -621,13 +622,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -621,13 +622,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
I1, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
I1, // MWaveId m1, // MWaveId
I1, // NWaveId n1, // NWaveId
I1, // MPerXdl m2, // MPerXdl
I1, // NGroupNum n2, // NGroupNum
I1, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // registerNum
auto d0s_thread_buf = generate_tuple( auto d0s_thread_buf = generate_tuple(
...@@ -644,9 +645,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -644,9 +645,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, n2, n4));
auto d0s_threadwise_copy = generate_tuple( auto d0s_threadwise_copy = generate_tuple(
[&](auto i) { [&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>; using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
...@@ -655,10 +653,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -655,10 +653,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
D0DataType, D0DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>, Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, 9,
n4, D0sTransferSrcScalarPerVector,
1, 1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx[I0], // MBlockId make_multi_index(block_work_idx[I0], // MBlockId
...@@ -884,62 +891,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -884,62 +891,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// multiple d // multiple d
if constexpr(NumD0Tensor) if constexpr(NumD0Tensor)
{ {
static_for<0, MXdlPerWave, 1>{}([&](auto mr) { static_assert(NXdlPerWave == n0);
static_for<0, NXdlPerWave, 1>{}([&](auto nr) { static_assert(MXdlPerWave == m0);
static_for<0, n2, 1>{}([&](auto groupid) {
static_for<0, NumD0Tensor, 1>{}([&](auto i) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run( d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0s_grid_buf[i],
d0s_grid_buf[i], d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), d0s_thread_buf(i));
d0s_thread_buf(i)); });
}); static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
// get reference to src data
static_for<0, n4, 1>{}([&](auto i) { const auto src_data_refs = generate_tie(
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( // return type should be lvalue
make_tuple(mr, nr, groupid, i)); [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
Number<NumD0Tensor>{});
// get reference to src data
const auto src_data_refs = generate_tie( // get reference to dst data
// return type should be lvalue auto dst_data_refs = generate_tie(
[&](auto iSrc) -> const auto& { // return type should be lvalue
return d0s_thread_buf[iSrc][i]; [&](auto) -> auto& { return acc_thread_buf(i); },
}, Number<2>{});
Number<NumD0Tensor>{});
unpack2(c0de_element_op, dst_data_refs, src_data_refs);
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& {
return acc_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(c0de_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0));
});
}); });
static_for<0, NumD0Tensor, 1>{}([&](auto i) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}); });
} }
else else
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/utility/functional2.hpp" #include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include <array>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
...@@ -14,29 +15,83 @@ ...@@ -14,29 +15,83 @@
namespace ck { namespace ck {
namespace detail { namespace detail {
template <unsigned Size> template <unsigned SizeInBytes>
struct get_unsigned_int; struct get_carrier;
template <> template <>
struct get_unsigned_int<1> struct get_carrier<1>
{ {
using type = uint8_t; using type = uint8_t;
}; };
template <> template <>
struct get_unsigned_int<2> struct get_carrier<2>
{ {
using type = uint16_t; using type = uint16_t;
}; };
template <> template <>
struct get_unsigned_int<4> struct get_carrier<3>
{
using type = class carrier
{
using value_type = uint32_t;
std::array<std::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n()
template <typename InputIterator, typename Size, typename OutputIterator>
__device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to)
{
if(0 < size)
{
*to = *from;
++to;
for(Size count = 1; count < size; ++count)
{
*to = *++from;
++to;
}
}
return to;
}
// method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept
{
copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
}
public:
__device__ carrier& operator=(value_type value) noexcept
{
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
return *this;
}
__device__ operator value_type() const noexcept
{
std::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result);
return *reinterpret_cast<const value_type*>(result);
}
};
};
static_assert(sizeof(get_carrier<3>::type) == 3);
template <>
struct get_carrier<4>
{ {
using type = uint32_t; using type = uint32_t;
}; };
template <unsigned Size> template <unsigned SizeInBytes>
using get_unsigned_int_t = typename get_unsigned_int<Size>::type; using get_carrier_t = typename get_carrier<SizeInBytes>::type;
} // namespace detail } // namespace detail
...@@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) ...@@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize)
{ {
using Sgpr = detail::get_unsigned_int_t<SgprSize>; using Sgpr = detail::get_carrier_t<SgprSize>;
*reinterpret_cast<Sgpr*>(to_obj + offset) = *reinterpret_cast<Sgpr*>(to_obj + offset) =
amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset)); amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset));
...@@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) ...@@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
if constexpr(0 < RemainedSize) if constexpr(0 < RemainedSize)
{ {
using Carrier = detail::get_unsigned_int_t<RemainedSize>; using Carrier = detail::get_carrier_t<RemainedSize>;
*reinterpret_cast<Carrier>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( *reinterpret_cast<Carrier*>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane(
*reinterpret_cast<const Carrier*>(from_obj + CompleteSgprCopyBoundary)); *reinterpret_cast<const Carrier*>(from_obj + CompleteSgprCopyBoundary));
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmMultiD<
ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceBatchedGemmMultiD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(op_ptrs);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// // MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// // MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=32, NPerBlock=32
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=32, NPerBlock=32
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances = std::tuple<
// clang-format off
// ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// MPerBlock=128, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=128, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=128
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=32, NPerBlock=32
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=16, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// MPerBlock=64, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=16, NPerBlock=16
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=8, NPerBlock=64
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// MPerBlock=64, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
// MPerBlock=8, NPerBlock=8
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment