Commit 8a370fbb authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into group_conv

parents d8fdd226 85978e02
#pragma once

#include <array>
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct GemmDesc
{
    ck::index_t M_, N_, K_;
    ck::index_t stride_A_, stride_B_, stride_C_;

    std::vector<ck::index_t> stride_Ds_;
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_a,
                        std::vector<const void*>& p_b,
                        std::vector<std::array<const void*, NumDTensor>>& p_ds,
                        std::vector<void*>& p_e,
                        std::vector<GemmDesc>& gemm_desc,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
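For orientation, a minimal host-side sketch of driving this interface (not part of the commit): it assumes PassThrough-style element ops, device-resident buffers, and the usual BaseOperator/BaseInvoker members from device_base.hpp.

// sketch only: drive any concrete DeviceGroupedGemm implementation
template <typename GroupedGemmOp, typename ElementOp>
float run_grouped_gemm_sketch(GroupedGemmOp& op,
                              std::vector<const void*>& p_as,
                              std::vector<const void*>& p_bs,
                              std::vector<std::array<const void*, GroupedGemmOp::NumDTensor>>& p_ds,
                              std::vector<void*>& p_es,
                              std::vector<ck::tensor_operation::device::GemmDesc>& gemm_descs,
                              ElementOp pass_through)
{
    // one A/B/E pointer and one GemmDesc per group
    auto arg_ptr     = op.MakeArgumentPointer(
        p_as, p_bs, p_ds, p_es, gemm_descs, pass_through, pass_through, pass_through);
    auto invoker_ptr = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(arg_ptr.get())) // assumed BaseOperator query
        return -1.f;

    return invoker_ptr->Run(arg_ptr.get()); // assumed BaseInvoker::Run(const BaseArgument*)
}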
#pragma once

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

...@@ -10,9 +11,10 @@

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...@@ -21,22 +23,20 @@ namespace tensor_operation {

namespace device {

template <typename GridwiseGemm,
          typename GemmDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                const index_t group_count,
                                const AElementwiseOperation a_element_op,
                                const BElementwiseOperation b_element_op,
                                const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

...@@ -65,42 +65,49 @@ __global__ void
    }

    GridwiseGemm::template Run<HasMainKBlockLoop>(
        gemm_desc_ptr[group_id].a_ptr_,
        gemm_desc_ptr[group_id].b_ptr_,
        gemm_desc_ptr[group_id].ds_ptr_,
        gemm_desc_ptr[group_id].e_ptr_,
        p_shared,
        a_element_op,
        b_element_op,
        c_element_op,
        gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
        gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
        gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
        gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
        gemm_desc_ptr[group_id].block_2_etile_map_);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
#endif
}
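For context, the folded hunk above is where each workgroup resolves its group: the flat block id falls into exactly one group's [BlockStart_, BlockEnd_) range. A sketch of that lookup (a linear scan for clarity; the folded code may narrow the search differently):

    // sketch only: find the group whose block range contains this workgroup
    const index_t block_id = get_block_1d_id();

    index_t group_id = 0;
    while(group_id < group_count &&
          !(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
            block_id < gemm_desc_ptr[group_id].BlockEnd_))
    {
        group_id++;
    }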
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
...@@ -111,163 +118,139 @@ template <typename ADataType,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                                                        BLayout,
                                                        DsLayout,
                                                        ELayout,
                                                        ADataType,
                                                        BDataType,
                                                        DsDataType,
                                                        EDataType,
                                                        AElementwiseOperation,
                                                        BElementwiseOperation,
                                                        CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedGemm_Xdl;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
    }

    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
    }

    template <typename ELay>
    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
    {
        const auto e_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideE, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideE));
            }
        }();

        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
    }

    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                                         const std::array<index_t, NumDTensor>& NRaws,
                                         const std::array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
            },
            Number<NumDTensor>{});
    }
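The MatrixPadder calls above replace the explicit right-pad transforms from the old descriptor builders; per dimension, what the removed code computed reduces to rounding each raw length up to the tile size:

    // scalar form of the removed PadM/PadN computation (illustrative helper)
    ck::index_t pad_up(ck::index_t x, ck::index_t tile)
    {
        return x + (tile - x % tile) % tile; // e.g. pad_up(1000, 128) == 1024
    }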
    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        NumPrefetch, // NumGemmKPrefetchStage
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
...@@ -277,7 +260,7 @@ struct DeviceGroupedGemmXdl
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
...@@ -285,169 +268,218 @@ struct DeviceGroupedGemmXdl
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    struct GroupedGemmBlock2ETileMap
    {
        using UnderlyingBlock2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;

        static_assert(
            std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
                         typename GridwiseGemm::DefaultBlock2ETileMap>::value,
            "Wrong! Should be the same type name");

        GroupedGemmBlock2ETileMap()
        {
            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
            BlockStart_        = -1;
        }

        GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
        {
            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
            BlockStart_        = BlockStart;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            return block_2_etile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[I0] - BlockStart_));
        }

        // it's actually E-Tile
        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                 const CTileDim& c_tile_dim) const
        {
            return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
        }

        __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
        {
            return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
        }

        typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
        ck::index_t BlockStart_;
    };
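The wrapper's only job is to rebase the fused launch's global block id into the group's local range before delegating to the default map; with made-up numbers:

    // toy illustration: a group that owns blocks [96, 160) of the fused grid
    const ck::index_t BlockStart      = 96;
    const ck::index_t global_block_id = 100;
    const ck::index_t local_block_id  = global_block_id - BlockStart; // 4, what the underlying map sees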
    struct GemmBiasTransKernelArg
    {
        // pointers
        const ADataType* a_ptr_;
        const BDataType* b_ptr_;
        typename GridwiseGemm::DsGridPointer ds_ptr_;
        EDataType* e_ptr_;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        GroupedGemmBlock2ETileMap block_2_etile_map_;

        ck::index_t BlockStart_, BlockEnd_;
    };
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation c_element_op)
            : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
        {
            grid_size_ = 0;

            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/p_Bs/p_Es.size");
            }

            gemm_desc_kernel_arg_.reserve(group_count_);

            for(std::size_t i = 0; i < gemm_descs.size(); i++)
            {
                const index_t M = gemm_descs[i].M_;
                const index_t N = gemm_descs[i].N_;
                const index_t K = gemm_descs[i].K_;

                const index_t StrideA = gemm_descs[i].stride_A_;
                const index_t StrideB = gemm_descs[i].stride_B_;
                const index_t StrideC = gemm_descs[i].stride_C_;

                // pointer
                typename GridwiseGemm::DsGridPointer p_ds_grid{};

                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
                });

                // tensor descriptors for problem definition
                const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
                const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);

                DsGridDesc_M_N ds_grid_desc_m_n;

                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;

                    ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
                        M, N, gemm_descs[i].stride_Ds_[j]);
                });

                const auto e_grid_desc_m_n =
                    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);

                // tensor descriptors for block/thread-wise copy
                const auto a_grid_desc_ak0_m_ak1 =
                    GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);

                const auto b_grid_desc_bk0_n_bk1 =
                    GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);

                const index_t grid_size_grp =
                    GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
                        .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);

                const index_t BlockStart = grid_size_;
                const index_t BlockEnd   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                const auto block_2_etile_map =
                    GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);

                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               block_2_etile_map))
                {
                    // tensor descriptors for block/thread-wise copy
                    typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                        ds_grid_desc_mblock_mperblock_nblock_nperblock;

                    static_for<0, NumDTensor, 1>{}([&](auto j) {
                        ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
                            GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                                ds_grid_desc_m_n[j]);
                    });

                    const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            e_grid_desc_m_n);

                    gemm_desc_kernel_arg_.push_back(
                        GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
                                               static_cast<const BDataType*>(p_Bs[i]),
                                               p_ds_grid,
                                               static_cast<EDataType*>(p_Es[i]),
                                               a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               a_grid_desc_ak0_m_ak1,
                                               b_grid_desc_bk0_n_bk1,
                                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                               e_grid_desc_mblock_mperblock_nblock_nperblock,
                                               block_2_etile_map,
                                               BlockStart,
                                               BlockEnd});
                }
            }
        }

        //  private:
        index_t group_count_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation c_element_op_;

        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;

        index_t grid_size_;
    };
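The BlockStart/BlockEnd bookkeeping in the constructor above is a running prefix sum over per-group tile counts; stripped of the descriptors it amounts to this (tile counts illustrative):

    // sketch: how the fused grid is carved into per-group block ranges
    #include <vector>

    struct BlockRange { int begin, end; }; // [BlockStart, BlockEnd)

    std::vector<BlockRange> carve_grid(const std::vector<int>& tiles_per_group)
    {
        std::vector<BlockRange> ranges;
        int grid_size = 0;
        for(int t : tiles_per_group)
        {
            ranges.push_back({grid_size, grid_size + t});
            grid_size += t; // the fused launch ends up with dim3(grid_size) blocks
        }
        return ranges;
    }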
...@@ -455,7 +487,7 @@ struct DeviceGroupedGemmXdl

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
...@@ -463,33 +495,39 @@ struct DeviceGroupedGemmXdl

            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
            {
                std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
                          << ", "
                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
                          << ", "
                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
                          << "}";

                std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
                          << ", "
                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
                          << ", "
                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
                          << "}";

                std::cout << ", arg.e_grid_desc_m_n_{ "
                          << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
                          << std::endl;

                if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
                                                arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
                                                arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
                                                arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
                                                arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
                {
                    throw std::runtime_error(
                        "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
                }

                const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                               arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);

                if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
                {
...@@ -500,58 +538,39 @@ struct DeviceGroupedGemmXdl

            hipGetErrorString(
                hipMemcpy(arg.p_workspace_,
                          arg.gemm_desc_kernel_arg_.data(),
                          arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
                          hipMemcpyHostToDevice));

            float ave_time = 0;

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
                                                            GemmBiasTransKernelArg,
                                                            AElementwiseOperation,
                                                            BElementwiseOperation,
                                                            CDEElementwiseOperation,
                                                            has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                    arg.gemm_desc_kernel_arg_.size(),
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
            };

            if(has_main_k_block_loop)
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
...@@ -565,18 +584,14 @@ struct DeviceGroupedGemmXdl
        }
    };
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
        {
            return false;
        }

        return true;
    }
    // polymorphic
...@@ -585,31 +600,34 @@ struct DeviceGroupedGemmXdl
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
                             std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation c_element_op)
    {
        return Argument{
            p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
    }
    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
                        std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                        std::vector<void*>& p_Es,
                        std::vector<GemmDesc>& gemm_descs,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(
            p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
    }
    // polymorphic
...@@ -624,13 +642,14 @@ struct DeviceGroupedGemmXdl
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedGemm_Xdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
...@@ -643,7 +662,7 @@ struct DeviceGroupedGemmXdl

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        return dynamic_cast<const Argument*>(p_arg)->group_count_ *
               sizeof(GemmBiasTransKernelArg);
    }
};
......
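Unlike a plain GEMM, this op needs a device-side workspace that holds one GemmBiasTransKernelArg per group (see GetWorkSpaceSize above) and must be wired up before Run. A hedged sketch, assuming the instance alias and buffers are prepared elsewhere and that BaseArgument exposes the p_workspace_ member Invoker::Run reads:

    // sketch only: workspace wiring for a concrete DeviceGroupedGemm_Xdl<...> alias
    auto op      = DeviceGroupedGemmInstance{}; // hypothetical alias, not from this commit
    auto arg     = op.MakeArgument(p_as, p_bs, p_ds, p_es, gemm_descs,
                                   PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker = op.MakeInvoker();

    void* p_workspace = nullptr;
    (void)hipMalloc(&p_workspace, op.GetWorkSpaceSize(&arg)); // per-group kernel args live here
    arg.p_workspace_ = p_workspace;                           // Invoker::Run hipMemcpy's into this

    float ave_time = invoker.Run(arg, StreamConfig{nullptr, true}); // time_kernel = true
    (void)hipFree(p_workspace);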
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using DsType = Tuple<>;

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  DsType,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Col,
                                                  DsType,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Row,
                                                  DsType,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Col,
                                                  DsType,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
    ALayout,
    BLayout,
    DsType,
    CLayout,
    ADataType,
    BDataType,
    DsDataType,
    EDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedGemm<ALayout,
                                       BLayout,
                                       DsType,
                                       CLayout,
                                       ADataType,
                                       BDataType,
                                       DsDataType,
                                       EDataType,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
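For orientation, this is roughly how the factory above is consumed; the buffer vectors and descriptors (p_as, p_bs, p_ds, p_es, gemm_descs) are assumed to be prepared elsewhere, and the selection loop is a sketch of what the CK profiler does rather than its exact code.

    using GroupedGemmF16 = ck::tensor_operation::device::DeviceGroupedGemm<
        Row, Row, DsType, Row, F16, F16, DsType, F16, PassThrough, PassThrough, PassThrough>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<GroupedGemmF16>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto arg_ptr     = op_ptr->MakeArgumentPointer(
            p_as, p_bs, p_ds, p_es, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(arg_ptr.get()))
        {
            // time this instance and keep the fastest
            std::cout << op_ptr->GetTypeString() << std::endl;
        }
    }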
...@@ -31,9 +31,8 @@ check_err(const std::vector<T>& out,
{
    if(out.size() != ref.size())
    {
        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

...@@ -50,9 +49,8 @@ check_err(const std::vector<T>& out,
            err_count++;
            if(err_count < 5)
            {
                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
            }
            res = false;
        }

...@@ -74,9 +72,8 @@ check_err(const std::vector<T>& out,
{
    if(out.size() != ref.size())
    {
        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

...@@ -96,9 +93,8 @@ check_err(const std::vector<T>& out,
            err_count++;
            if(err_count < 5)
            {
                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }

...@@ -120,9 +116,8 @@ check_err(const std::vector<T>& out,
{
    if(out.size() != ref.size())
    {
        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

...@@ -141,9 +136,8 @@ check_err(const std::vector<T>& out,
            err_count++;
            if(err_count < 5)
            {
                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }

...@@ -165,9 +159,8 @@ check_err(const std::vector<T>& out,
{
    if(out.size() != ref.size())
    {
        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

...@@ -187,9 +180,9 @@ check_err(const std::vector<T>& out,
            err_count++;
            if(err_count < 5)
            {
                std::cout << msg << " out[" << i << "] != ref[" << i
                          << "]: " << static_cast<int>(out[i]) << " != " << static_cast<int>(ref[i])
                          << std::endl;
            }
            res = false;
        }
......
...@@ -433,52 +433,3 @@ struct Tensor

    HostTensorDescriptor mDesc;
    std::vector<T> mData;
};
#if 1
// FIXME: remove
template <typename T>
float check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
    float l1_error       = 0;
    float linf_error     = -1;
    float linf_rel_error = -1;

    float linf_ref_value = 0, linf_result_value = 0;
    float linf_rel_ref_value = 0, linf_rel_result_value = 0;

    constexpr float eps = 1e-10;

    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        float ref_v    = ck::type_convert<float>(ref.mData[i]);
        float result_v = ck::type_convert<float>(result.mData[i]);

        float diff     = std::abs(ref_v - result_v);
        float rel_diff = diff / std::max(std::abs(ref_v), eps);

        l1_error += diff;

        if(linf_error < diff)
        {
            linf_error        = diff;
            linf_ref_value    = ref_v;
            linf_result_value = result_v;
        }

        if(linf_rel_error < rel_diff)
        {
            linf_rel_error        = rel_diff;
            linf_rel_ref_value    = ref_v;
            linf_rel_result_value = result_v;
        }
    }

    std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl;
    std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref "
              << linf_ref_value << ", result " << linf_result_value << std::endl;
    std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref "
              << linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl;

    return linf_error;
}
#endif
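The removed helper's metrics, worked on a two-element example for reference:

    // ref = {1.0, 2.0}, result = {1.0, 2.5}:
    //   L1 norm        = |1.0-1.0| + |2.0-2.5|         = 0.5
    //   L-inf norm     = max(|1.0-1.0|, |2.0-2.5|)     = 0.5
    //   relative L-inf = 0.5 / max(|2.0|, eps)         = 0.25
    // check_err (above) replaces this with elementwise tolerance checks that
    // print the first few failing indices instead of aggregate norms.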
...@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using DsType = ck::Tuple<>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

...@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa

// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
    // clang-format off
    //####################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //####################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //####################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128,  64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128,  64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128,  64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128,  64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128,  64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128,  64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256,  64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl<     Col,     Row,   DsType,     Row,   F16,   F16,     F32,      F16, DsType,   F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256,  64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Row,
                                                  DsType,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
......
...@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using DsType = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
    // clang-format off
    //##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                  Col,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
}

...
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using DsType = ck::Tuple<>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -30,32 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    // clang-format off
    //##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsType| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
}

...
@@ -23,53 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using DsType = ck::Tuple<>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    // clang-format off
    //##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
    // clang-format on
    >;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Col,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  DsType,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
}

} // namespace instance

...
@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
        reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
        reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());

        bool c_error =
            ck::utils::check_err(c_g_m_n_host_result.mData, c_g_m_n_device_result.mData);
        bool d0_error =
            ck::utils::check_err(d0_g_m_host_result.mData, d0_g_m_device_result.mData);
        bool d1_error =
            ck::utils::check_err(d1_g_m_host_result.mData, d1_g_m_device_result.mData);

        pass = pass && (c_error == true);
        pass = pass && (d0_error == true);
        pass = pass && (d1_error == true);

        if(do_log)
        {
...
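// Hedged aside (not in this commit): ck::utils::check_err replaces the old
// float-returning check_error with a bool verdict, which is what the pass/fail
// logic in the hunk above consumes. A minimal stand-in using a fixed absolute
// tolerance, to make that behavior concrete:

#include <cmath>
#include <cstddef>
#include <vector>

template <typename T>
bool check_err_sketch(const std::vector<T>& out, const std::vector<T>& ref, double atol = 1e-6)
{
    // Mismatched sizes count as failure rather than reading out of bounds.
    if(out.size() != ref.size())
        return false;

    for(std::size_t i = 0; i < out.size(); ++i)
    {
        if(std::abs(static_cast<double>(out[i]) - static_cast<double>(ref[i])) > atol)
            return false;
    }
    return true;
}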
@@ -159,7 +159,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
    float best_gb_per_sec = 0;

    // profile device Conv instances
    bool all_pass = true;

    for(auto& op_ptr : op_ptrs)
    {
@@ -215,8 +215,15 @@ bool profile_conv_bwd_weight_impl(int do_verification,
        {
            wei_device_buf.FromDevice(weight_device_result.mData.data());

            bool pass = ck::utils::check_err(wei_k_c_y_x_host_result.mData,
                                             wei_k_c_y_x_device_result.mData);
            if(!pass)
            {
                std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
            }
            all_pass &= pass;
            if(do_log)
            {
@@ -248,7 +255,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return all_pass;
} }
} // namespace profiler } // namespace profiler
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
template <typename WeiLayout>
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
template <typename OutLayout>
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
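// Note: in all three helpers above, every case calls the same
// get_host_tensor_descriptor overload, so the switch only rejects unsupported
// num_dim_spatial values. A hedged equivalent of each body:
//
//   if(num_dim_spatial < 1 || num_dim_spatial > 3)
//       throw std::runtime_error("Unsupported number of spatial dimensions provided!");
//   return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});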
template <typename InDataType, typename WeiDataType, typename OutDataType>
void get_device_conv_bwd_data_op_ptr(
InDataType, WeiDataType, OutDataType, std::vector<DeviceConvBwdDataNoOpPtr>&, int)
{
std::cout << "can not find device conv bwd data" << std::endl;
exit(1);
}
template <>
void get_device_conv_bwd_data_op_ptr(
F32, F32, F32, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
break;
default: break;
}
}
template <>
void get_device_conv_bwd_data_op_ptr(
F16, F16, F16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break;
default: break;
}
}
template <>
void get_device_conv_bwd_data_op_ptr(
BF16, BF16, BF16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break;
default: break;
}
}
template <>
void get_device_conv_bwd_data_op_ptr(
INT8, INT8, INT8, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
break;
default: break;
}
}
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
float max_diff = 1e-6;
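    // This is a fixed absolute tolerance: the loop fails on the first element
    // whose |ref - result| exceeds 1e-6, which is strict for f16/bf16 data
    // (a relative tolerance would scale with magnitude).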
for(std::size_t i = 0; i < ref.mData.size(); ++i)
{
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
return false;
}
}
return true;
}
template <typename DataType>
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
{
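    // Lengths are indexed in logical (N, C, H, W) order; the print order
    // n -> h -> w -> c (C innermost) matches the physical NHWC layout named
    // by this function.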
std::cout << "[";
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
{
std::cout << "[";
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
{
std::cout << "[";
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
{
std::cout << "[";
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
{
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
template <int NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool profile_convnd_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
std::vector<std::size_t> input_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(C)};
input_dims.insert(
std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(K), static_cast<std::size_t>(C)};
filter_dims.insert(std::end(filter_dims),
std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths));
std::vector<std::size_t> output_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(K)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
Tensor<InDataType> input_host_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
Tensor<InDataType> input_device_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
Tensor<WeiDataType> weights(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
Tensor<OutDataType> output(
        get_output_host_tensor_descriptor<OutLayout>(output_dims, NDimSpatial));
std::cout << "input: " << input_host_result.mDesc << std::endl;
std::cout << "weights: " << weights.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
out_device_buf.ToDevice(output.mData.data());
wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero
in_device_buf.SetZero();
if(do_verification)
{
auto RunReference = [&](auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input_host_result,
weights,
output,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
};
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NDimSpatial>();
RunReference(ref_conv);
}
// add device Conv instances
std::vector<DeviceConvBwdDataNoOpPtr> conv_ptrs;
get_device_conv_bwd_data_op_ptr(
InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
std::string best_conv_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device Conv instances
bool success = true;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string conv_name = conv_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop =
ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
std::size_t num_btype =
ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
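            // ave_time is in ms: flop / 1e9 / ms = TFLOP/s, and bytes / 1e6 / ms = GB/s.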
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
if(tflops > best_tflops)
{
best_conv_name = conv_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
in_device_buf.FromDevice(input_device_result.mData.data());
if(!check_out(input_host_result, input_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
}
                // Accumulate across instances instead of overwriting the
                // verdict from earlier iterations.
                success =
                    success && ck::utils::check_err(input_host_result.mData,
                                                    input_device_result.mData);
if(do_log)
{
std::cout << "in : ";
show_data_nhwc_layout(output);
std::cout << std::endl;
std::cout << "wei: ";
show_data_nhwc_layout(weights);
std::cout << std::endl;
std::cout << "out_host : ";
show_data_nhwc_layout(input_host_result);
std::cout << std::endl;
std::cout << "out_device: ";
show_data_nhwc_layout(input_device_result);
std::cout << std::endl;
}
}
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
return success;
}
} // namespace profiler
} // namespace ck
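// Hedged usage sketch (not part of this header; all sizes are illustrative
// placeholders): invoking the profiler above for a 2-D FP16 NHWC backward-data
// convolution. With 71x71 input, 3x3 filter, stride 2, dilation 1, and pad 1,
// the output spatial size is (71 + 2 - 2 - 1) / 2 + 1 = 36.
//
//   namespace tl = ck::tensor_layout::convolution;
//   bool ok = ck::profiler::profile_convnd_bwd_data_impl<2, F16, F16, F16, F32,
//                                                        tl::NHWC, tl::KYXC, tl::NHWK>(
//       /*do_verification=*/1, /*init_method=*/1, /*do_log=*/false, /*time_kernel=*/true,
//       /*N=*/128, /*K=*/256, /*C=*/192,
//       /*input_spatial=*/{71, 71}, /*filter=*/{3, 3}, /*output_spatial=*/{36, 36},
//       /*strides=*/{2, 2}, /*dilations=*/{1, 1}, /*left_pads=*/{1, 1}, /*right_pads=*/{1, 1});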
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp"
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DeviceConvndBwdWeightNoOpPtr =
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
using DeviceConvndBwdWeightNoOpPtr =
ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr;
template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
template <typename WeiLayout>
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
template <typename OutLayout>
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
switch(num_dim_spatial)
{
case 3: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 2: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 1: {
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
void get_device_conv_bwd_weight_op_ptr(
InDataType, WeiDataType, OutDataType, std::vector<DeviceConvndBwdWeightNoOpPtr>&, int)
{
std::cout << "can not find device conv bwd weight" << std::endl;
exit(1);
}
template <>
void get_device_conv_bwd_weight_op_ptr(
F32, F32, F32, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
break;
default: break;
}
}
template <>
void get_device_conv_bwd_weight_op_ptr(
F16, F16, F16, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break;
default: break;
}
}
template <>
void get_device_conv_bwd_weight_op_ptr(
BF16, BF16, BF16, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break;
default: break;
}
}
template <typename DataType>
void show_data_nhwc_layout(const Tensor<DataType>& nhwc)
{
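// mDesc lengths are stored in (N, C, H, W) order; the nested loops below print the
// values in N, H, W, C order, matching the NHWC name of this function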
std::cout << "[";
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
{
std::cout << "[";
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
{
std::cout << "[";
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
{
std::cout << "[";
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
{
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
template <int NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool profile_convnd_bwd_weight_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t split_k)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
std::vector<std::size_t> input_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(C)};
input_dims.insert(
std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(K), static_cast<std::size_t>(C)};
filter_dims.insert(std::end(filter_dims),
std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths));
std::vector<std::size_t> output_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(K)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
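// For example (hypothetical sizes), a 2-D problem with N=4, C=8, K=16, 28x28 inputs
// and 3x3 filters gives input_dims = {4, 8, 28, 28}, filter_dims = {16, 8, 3, 3} and
// output_dims = {4, 16, 28, 28}: dims are always ordered {N or K, C, spatial...},
// independent of the layout template argument.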
Tensor<InDataType> input(get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
Tensor<WeiDataType> weights_host_result(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
Tensor<WeiDataType> weights_device_result(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
Tensor<OutDataType> output(
get_output_host_tensor_descriptor<OutLayout>(output_dims, NDimSpatial));
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weights: " << weights_host_result.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break;
default:
input.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
in_device_buf.ToDevice(input.mData.data());
out_device_buf.ToDevice(output.mData.data());
// reset the weight output buffer to zero (bwd-weight kernels accumulate into it)
wei_device_buf.SetZero();
if(do_verification)
{
auto RunReference = [&](auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input,
weights_host_result,
output,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
};
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NDimSpatial>();
RunReference(ref_conv);
}
// add device Conv instances
std::vector<DeviceConvndBwdWeightNoOpPtr> conv_ptrs;
get_device_conv_bwd_weight_op_ptr(
InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);
if(conv_ptrs.empty())
{
throw std::runtime_error("wrong! no device Conv instance found");
}
std::string best_conv_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device Conv instances
bool success = true;
for(auto& conv_ptr : conv_ptrs)
{
// split-K kernels accumulate the weight gradient with atomics, so the weight
// buffer must be zeroed before each run; the invoker performs SetZero() itself
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k);
if(!conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
continue;
}
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
std::string conv_name = conv_ptr->GetTypeString();
float ave_time = 0;
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
{
// allocate the workspace used for split-K partial accumulation
size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get());
if(bwd_weight_workspace_size == 0)
{
std::cerr << "wrong! workspace size must be non-zero for split-K bf16" << std::endl;
exit(1);
}
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
conv_ptr->SetWorkSpacePointer(argument_ptr.get(),
wei_work_space_device_buf.GetDeviceBuffer());
ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
}
else
{
ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
}
std::size_t flop =
ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
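// flop is presumably 2 * N * C * K * (filter spatial size) * (output spatial size),
// i.e. one multiply and one add per MAC; ave_time is in ms, so
// flop / 1.E9 / ave_time == flop / (1e12 * seconds) == TFLOPS below.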
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
if(tflops > best_tflops)
{
best_conv_name = conv_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
wei_device_buf.FromDevice(weights_device_result.mData.data());
const bool instance_pass =
ck::utils::check_err(weights_host_result.mData, weights_device_result.mData);
success = success && instance_pass;
if(!instance_pass)
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
}
if(do_log)
{
std::cout << "in : ";
show_data_nhwc_layout(output);
std::cout << std::endl;
std::cout << "wei: ";
show_data_nhwc_layout(weights_host_result);
std::cout << std::endl;
std::cout << "out : ";
show_data_nhwc_layout(input);
std::cout << std::endl;
std::cout << "wei_device: ";
show_data_nhwc_layout(weights_device_result);
std::cout << std::endl;
}
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
return success;
}
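// Example invocation (hypothetical sizes): profile 2-D f32 bwd-weight with
// NHWC/KYXC/NHWK layouts, N=128, K=256, C=192, 28x28 input, 3x3 filters, stride 1,
// dilation 1 and pad 1 on both sides (so the output is also 28x28), split_k = 1:
//
// profile_convnd_bwd_weight_impl<2, float, float, float,
// ck::tensor_layout::convolution::NHWC,
// ck::tensor_layout::convolution::KYXC,
// ck::tensor_layout::convolution::NHWK>(
// 1, 1, false, true, 128, 256, 192,
// {28, 28}, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);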
} // namespace profiler
} // namespace ck
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -17,41 +19,17 @@ ...@@ -17,41 +19,17 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck { namespace ck {
namespace profiler { namespace profiler {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename AccDataType, typename AccDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
void profile_grouped_gemm_impl(int do_verification, bool profile_grouped_gemm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
...@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs) const std::vector<int>& StrideCs)
{ {
bool pass = true;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
...@@ -86,7 +67,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -86,7 +67,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<Tensor<ADataType>> a_m_k; std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n; std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_device_results; std::vector<Tensor<EDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
...@@ -96,7 +77,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -96,7 +77,7 @@ void profile_grouped_gemm_impl(int do_verification,
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back( c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
...@@ -115,7 +96,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -115,7 +96,7 @@ void profile_grouped_gemm_impl(int do_verification,
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread); b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
} }
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread); c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<EDataType>{}, num_thread);
} }
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -145,9 +126,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -145,9 +126,9 @@ void profile_grouped_gemm_impl(int do_verification,
p_b.reserve(group_count); p_b.reserve(group_count);
p_c.reserve(group_count); p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_shapes.reserve(group_count); gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
...@@ -157,56 +138,34 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -157,56 +138,34 @@ void profile_grouped_gemm_impl(int do_verification,
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>( c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); sizeof(EDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]}); gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
} }
// add device GEMM instances using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
std::vector<ck::tensor_operation::device::instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs; BLayout,
CLayout,
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && ADataType,
is_same<CDataType, half_t>::value) BDataType,
{ ck::Tuple<>,
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && EDataType,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && AElementOp,
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) BElementOp,
{ CElementOp>;
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
} DeviceOp>::GetInstances();
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && if(op_ptrs.size() <= 0)
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
if(gemm_ptrs.size() <= 0)
{ {
throw std::runtime_error("wrong! no device GEMM instance found"); throw std::runtime_error("wrong! no device GEMM instance found");
} }
...@@ -216,14 +175,17 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -216,14 +175,17 @@ void profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& gemm_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a, gemm_ptr->MakeArgumentPointer(p_a,
p_b, p_b,
p_ds,
p_c, p_c,
gemm_shapes, gemm_descs,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
...@@ -242,12 +204,12 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -242,12 +204,12 @@ void profile_grouped_gemm_impl(int do_verification,
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i]; sizeof(EDataType) * Ms[i] * Ns[i];
} }
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -266,18 +228,18 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -266,18 +228,18 @@ void profile_grouped_gemm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
Tensor<CDataType> c_m_n_host_result( Tensor<EDataType> c_m_n_host_result(
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType, ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, EDataType,
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
...@@ -294,7 +256,8 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -294,7 +256,8 @@ void profile_grouped_gemm_impl(int do_verification,
c_element_op); c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData); pass = pass && ck::utils::check_err(c_m_n_device_results[i].mData,
c_m_n_host_result.mData);
if(do_log) if(do_log)
{ {
...@@ -319,6 +282,8 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -319,6 +282,8 @@ void profile_grouped_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
} }
} // namespace profiler } // namespace profiler
} // namespace ck } // namespace ck
File mode changed from 100644 to 100755
...@@ -85,7 +85,6 @@ def parse_logfile(logfile): ...@@ -85,7 +85,6 @@ def parse_logfile(logfile):
for line in open(logfile): for line in open(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
print("len(lst)=",len(lst),"lst:",lst)
if len(lst)>=37: #the line is complete if len(lst)>=37: #the line is complete
tests.append(glue.join(lst[5:30])) tests.append(glue.join(lst[5:30]))
kernels.append(glue.join(lst[37:])) kernels.append(glue.join(lst[37:]))
...@@ -293,4 +292,4 @@ def main(): ...@@ -293,4 +292,4 @@ def main():
return regression return regression
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
#!/bin/bash
#
# in order to run this script you'd need the following python packages:
pip3 install --upgrade pip
pip3 install sqlalchemy pymysql pandas sshtunnel
# you would also need to set up some environment variables in order to
# post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details
#process results
gpu_arch=$1
python3 process_perf_data.py perf_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N265_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N4_"$gpu_arch".log
\ No newline at end of file
#!/bin/bash
#
# in order to run this script you'd need the following python packages:
pip3 install --upgrade pip
pip3 install sqlalchemy pymysql pandas sshtunnel
# you would also need to set up some environment variables in order to
# post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details
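# usage: bash <this_script>.sh <gpu_arch>   (e.g. gfx90a, a hypothetical example);
# the perf_*.log files processed below must already exist in the working directory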
#process results
gpu_arch=$1
python3 process_perf_data.py perf_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N265_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N4_"$gpu_arch".log
python3 process_perf_data.py perf_batched_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_grouped_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_fwd_conv_"$gpu_arch".log
python3 process_perf_data.py perf_bwd_conv_"$gpu_arch".log
python3 process_perf_data.py perf_fusion_"$gpu_arch".log
python3 process_perf_data.py perf_reduction_"$gpu_arch".log
\ No newline at end of file
...@@ -11,26 +11,34 @@ INIT=$5 ...@@ -11,26 +11,34 @@ INIT=$5
LOG=$6 LOG=$6
REPEAT=$7 REPEAT=$7
######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount OP=$1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 8 DATATYPE=$2
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 8 LAYOUT=$3
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 4 VERIFY=$4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 2 INIT=$5
LOG=$6
REPEAT=$7
######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 -1 -1 -1 2
\ No newline at end of file \ No newline at end of file
#!/bin/bash
## GPU visibility
export HIP_VISIBLE_DEVICES=0
DRIVER="../build/bin/ckProfiler"
OP=$1
DATATYPE=$2
LAYOUT=$3
VERIFY=$4
INIT=$5
LOG=$6
TIME=$7
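# note: a stride of -1 below lets the profiler pick a default packed stride;
# StrideD = 0 (used in some blocks) presumably broadcasts the D (bias) tensor along M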
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 -1 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 0 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1000 1000 1000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2000 2000 2000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4000 4000 4000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8000 8000 8000 -1 -1 0 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 1056 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 2080 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4128 4128 4128 4128 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8224 8224 8224 8224 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 1088 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 2112 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4160 4160 4160 4160 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8256 8256 8256 8256 1 1
\ No newline at end of file