Unverified Commit 3f299c33 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Merge branch 'develop' into aosewski/ggemm_dl_instances

parents 507d793a 091570f5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -58,16 +58,16 @@ __global__ void
c_element_op,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
......@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
......@@ -18,60 +18,23 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
static_cast<void*>(p_shared_block),
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared));
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -79,13 +42,13 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CMNGridDesc,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
......@@ -126,10 +89,238 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
struct Argument : public ck::tensor_operation::device::BaseArgument
{
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t K0;
index_t k_batch;
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
p_c_grid(p_c_grid_),
M(M_),
N(N_),
K(K_),
StrideA(StrideA_),
StrideB(StrideB_),
StrideC(StrideC_),
MPadded(MPadded_),
NPadded(NPadded_),
KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "K0:" << K0 << ", "
<< "KB:" << k_batch << "}" << std::endl;
}
};
__host__ __device__ static auto CalculateGridSize(const Argument& karg)
{
return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
math::integer_divide_ceil(karg.M, MPerBlock),
karg.k_batch);
}
// prefer this to be called on host
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
index_t MPad,
index_t K,
index_t StrideA,
index_t KBatch,
index_t K0,
index_t KPad)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
index_t NPad,
index_t N,
index_t StrideB,
index_t KBatch,
index_t K0,
index_t KPad)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
......@@ -178,45 +369,68 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
c_block_size * sizeof(FloatC));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
const Block2CTileMap& block_2_ctile_map)
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
return false;
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
return false;
}
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
return false;
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
return false;
}
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return false;
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
else
{
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
......@@ -224,8 +438,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return has_main_k0_block_loop;
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
......@@ -242,10 +457,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
// return block_id to C matrix tile idx (m0, n0) mapping
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
c_m_n_grid_desc, 8, KBatch);
}
......@@ -262,24 +478,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const AElementwiseOperation a_element_op = AElementwiseOperation{};
const BElementwiseOperation b_element_op = BElementwiseOperation{};
const CElementwiseOperation c_element_op = CElementwiseOperation{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -289,26 +506,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
......@@ -444,7 +651,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
......@@ -647,7 +853,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
......@@ -716,6 +922,48 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
});
}
}
template <typename Layout>
struct LStr
{
static std::string Get() { return ""; }
};
template <>
struct LStr<ck::tensor_layout::gemm::RowMajor>
{
static std::string Get() { return "R"; }
};
template <>
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
{
static std::string Get() { return "C"; }
};
static std::string GetTypeString()
{
auto str = std::stringstream();
// clang-format off
str << "GemmXdlSplitKCShuffle_"
<< getGemmSpecializationString(GemmSpec) << "_"
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< "_"
<< "B" << BlockSize << "_"
<< "Vec" << ABlockTransferSrcScalarPerVector << "x"
<< BBlockTransferSrcScalarPerVector << "x"
<< CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
<< MPerBlock << "x"
<< NPerBlock << "x"
<< K0PerBlock << "x"
<< K1 ;
// clang-format on
return str.str();
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -1008,6 +1008,60 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
template <typename T>
struct NumericLimits
{
......
......@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
static inline __host__ double tanh(double x) { return std::tanh(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
......@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
static inline __device__ double tanh(double x) { return ::tanh(x); };
} // namespace math
} // namespace ck
......@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
......@@ -102,6 +103,10 @@ template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Mul_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
......@@ -109,6 +114,10 @@ template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Mul2_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory;
......
......@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
......@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -49,6 +49,21 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -80,6 +95,22 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
......@@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -26,7 +26,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
......@@ -35,14 +36,22 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>
// clang-format on
>;
......
......@@ -25,6 +25,7 @@ using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GK = ck::tensor_layout::convolution::G_K;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
......@@ -32,17 +33,25 @@ using I32_Tuple = ck::Tuple<int32_t>;
using F32_Tuple = ck::Tuple<float>;
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
// perlayer
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
// bias + perlayer
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
using Add_Mul_TanH_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<TanH>;
// perchannel
using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<PassThrough>;
using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Relu>;
// bias + perchannel
using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<PassThrough>;
using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Relu>;
using Add_Mul2_TanH_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<TanH>;
static constexpr ck::index_t NDimSpatial = 2;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......
......@@ -76,6 +76,42 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
ConvFwd1x1S1P0,
4>{});
}
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_TanH_Mul_Clamp>>>& instances)
{
// dl
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwdDefault,
4>{});
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwd1x1P0,
4>{});
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwd1x1S1P0,
4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -76,6 +76,43 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
ConvFwd1x1S1P0,
4>{});
}
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_TanH_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwdDefault,
4>{});
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwd1x1P0,
4>{});
add_device_operation_instances(instances,
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwd1x1S1P0,
4>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -74,6 +74,41 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
ConvFwd1x1S1P0,
8>{});
}
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_TanH_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwdDefault,
8>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwd1x1P0,
8>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
I32_F32_Tuple,
Add_Mul2_TanH_Mul_Clamp,
ConvFwd1x1S1P0,
8>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -76,6 +76,43 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
ConvFwd1x1S1P0,
8>{});
}
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_TanH_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwdDefault,
8>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwd1x1P0,
8>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
I32_Tuple,
Add_Mul_TanH_Mul_Clamp,
ConvFwd1x1S1P0,
8>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -8,8 +8,8 @@ MY_PROJECT_SOURCE=$1
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -mcumode \
-mno-wavefrontsize64 -Wno-gnu-line-marker -save-temps=$PWD" \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a" \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment