Unverified Commit a8629a98 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into gemm_v2r3_kpad_fix

parents 8dc713ea 94bfa502
...@@ -27,6 +27,12 @@ struct PassThrough ...@@ -27,6 +27,12 @@ struct PassThrough
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -69,18 +75,36 @@ struct PassThrough ...@@ -69,18 +75,36 @@ struct PassThrough
y = type_convert<bhalf_t>(x); y = type_convert<bhalf_t>(x);
} }
template <>
__host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
{
y = type_convert<float>(x);
}
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{ {
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{
y = type_convert<half_t>(x);
}
template <> template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{ {
y = type_convert<int8_t>(x); y = type_convert<int8_t>(x);
} }
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
...@@ -89,6 +113,7 @@ struct PassThrough ...@@ -89,6 +113,7 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -118,6 +143,7 @@ struct PassThrough ...@@ -118,6 +143,7 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -146,6 +172,7 @@ struct ConvertBF16RTN ...@@ -146,6 +172,7 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -162,6 +189,7 @@ struct ConvertF8SR ...@@ -162,6 +189,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
...@@ -412,14 +440,19 @@ struct Swish ...@@ -412,14 +440,19 @@ struct Swish
{ {
Swish(float beta = 1.0f) : beta_(beta) {} Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<T, ck::half_t>::value, is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x)); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; };
float beta_ = 1.0f; float beta_ = 1.0f;
......
...@@ -7,11 +7,9 @@ ...@@ -7,11 +7,9 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
...@@ -19,8 +17,6 @@ ...@@ -19,8 +17,6 @@
namespace ck { namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -29,8 +25,7 @@ template <typename GridwiseGemm, ...@@ -29,8 +25,7 @@ template <typename GridwiseGemm,
typename CGridDesc_M0_M10_M11_N0_N10_N11, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop, bool HasDoubleTailKBlockLoop>
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -43,13 +38,6 @@ __global__ void ...@@ -43,13 +38,6 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -100,8 +88,7 @@ template <index_t BlockSize, ...@@ -100,8 +88,7 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector>
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3 struct GridwiseGemmDl_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
c_grid_desc_m_n); c_grid_desc_m_n);
} }
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
...@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>(); BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
const auto c_grid_desc_m_n = amd_wave_read_first_lane(
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));
GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n);
#else
ignore = karg;
#endif
}
template <index_t BlockSize,
typename ABDataType,
typename AccDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerDpp,
index_t NPerDpp,
index_t AK1Value,
index_t BK1Value,
index_t MDppPerWave,
index_t NDppPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
static constexpr auto max_lds_align = math::lcm(AK1, BK1);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
__host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }
// Argument
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
AK0{CalculateAK0(K)},
BK0{CalculateBK0(K)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t AK0;
index_t BK0;
};
// Argument
struct Argument : public Problem, public tensor_operation::device::BaseArgument
{
__host__ Argument(const ABDataType* p_a_grid_,
const ABDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const ABDataType* p_a_grid;
const ABDataType* p_b_grid;
CDataType* p_c_grid;
};
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1), max_lds_align);
}
}();
return a_block_desc_ak0_m_ak1;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + 1>{} * BK1, BK1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1), max_lds_align);
}
}();
return b_block_desc_bk0_n_bk1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
}
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
"Wrong! AK1 must be known at the time of compilation.");
static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
"Wrong! BK1 must be known at the time of compilation.");
static_assert(
MPerBlock % (MPerDpp * MDppPerWave) == 0,
"Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
static_assert(
NPerBlock % (NPerDpp * NDppPerWave) == 0,
"Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");
static_assert(
KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");
static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
"Invalid tuning parameters! AK1Value must be divisible by "
"ABlockTransferDstScalarPerVector_K1");
static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
"Invalid tuning parameters! BK1Value must be divisible by "
"BBlockTransferDstScalarPerVector_K1");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if(problem.K % KPerBlock != 0)
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = problem.K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const auto num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
{
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);
using BlockwiseGemm =
BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerDpp,
NPerDpp,
MDppPerWave,
NDppPerWave,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
}
static constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
__device__ static auto
MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
__device__ static auto
MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_pass_through_transform(N),
make_unmerge_transform(make_tuple(BK0, BK1Value))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
__device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto block_2_ctile_map =
Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[AK0PerBlock, MPerBlock] is in LDS
// b_mtx[BK0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);
auto blockwise_gemm =
BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerDpp,
NPerDpp,
MDppPerWave,
NDppPerWave,
KPack>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);
// gridwise GEMM pipeline
const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
// (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// output: register to global memory
{
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, MPerThread, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_grid_desc_m0_n0_m1_n1_m2_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_n2,
c_grid_buf);
}
}
};
} // namespace ck
...@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
......
...@@ -29,7 +29,9 @@ namespace ck { ...@@ -29,7 +29,9 @@ namespace ck {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// Assume: // Assume:
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
template <typename ABDataType, // FIXME: don't assume A/B have same datatype template <typename ADataType,
typename BDataType,
typename ComputeType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
...@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
using GridwiseGemmPipe = remove_cvref_t< using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
using ABDataTypeAdjusted = ABDataType;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType), sizeof(ComputeType),
c_block_size * sizeof(CShuffleDataType)); c_block_size * sizeof(CShuffleDataType));
} }
...@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// check tensor size: cannot be larger than 2GB each // check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31); constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{ {
return false; return false;
...@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEElementwiseOperation_, typename CDEElementwiseOperation_,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid, __device__ static void Run(const ADataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
...@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
Sequence<1, AK0PerBlock, MPerBlock, AK1>, Sequence<1, AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABDataType, ADataType,
ABDataTypeAdjusted, ComputeType,
decltype(a_grid_desc_kbatch_ak0_m_ak1), decltype(a_grid_desc_kbatch_ak0_m_ak1),
decltype(a_block_desc_kbatch_ak0_m_ak1), decltype(a_block_desc_kbatch_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
Sequence<1, BK0PerBlock, NPerBlock, BK1>, Sequence<1, BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
ABDataType, BDataType,
ABDataTypeAdjusted, ComputeType,
decltype(b_grid_desc_kbatch_bk0_n_bk1), decltype(b_grid_desc_kbatch_bk0_n_bk1),
decltype(b_block_desc_kbatch_bk0_n_bk1), decltype(b_block_desc_kbatch_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1, BK1), math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataTypeAdjusted, ComputeType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared), static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned, static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
...@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const index_t KBatch, const index_t KBatch,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_); const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_); const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_); const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
using DsGridDesc_M_N = using DsGridDesc_M_N =
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector,
typename Block2ETileMap>
struct GridwiseImageToColumn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__device__ static void Run(const InputGridDesc& in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc& out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap& block_2_tile_map)
{
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
Tuple<InputDataType>,
Tuple<OutputDataType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<MPerBlock, KPerBlock>,
ThreadClusterLengths,
Sequence<0, 1>,
Sequence<0, 1>,
I1,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{
in_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
out_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
}
__host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
const OutputGridDesc& out_grid_desc)
{
if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
in_grid_desc.GetLength(I1) % KPerBlock != 0)
return false;
if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
out_grid_desc.GetLength(I1) % KPerBlock != 0)
return false;
return true;
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert dst_vector.template AsType<DstData>()(i) = v;
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
}); });
const bool is_dst_valid = const bool is_dst_valid =
...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v; DstData v;
// apply element-wise operation // apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]); element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert // apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v); dst_buf(Number<dst_offset>{}) = v;
}); });
}); });
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
namespace ck {
enum struct DppInstr
{
dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2
};
/**
* Structure representing DPP GEMM executed by a single wavefront.
*
* Each structure instantiation must contain the following fields:
* - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
* number of threads in a wavefront;
* - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
* it's 8 in case of DPP8;
* - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
* - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
* - m_per_thread - size along M dimension of the tile calculated by a single thread;
* - n_per_thread - size along N dimension of the tile calculated by a single thread;
* - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
* - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
*
* Not all the combinarions are supported now, for current restrictions see the static asserts
* in the DppSelector's contructor.
*/
template <DppInstr instr>
struct dpp_type;
template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 32;
static constexpr index_t n_per_wave = 8;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 16;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 1;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
static constexpr auto GetDpp();
template <>
static constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
{
static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
static_assert(selected_dpp.k_per_dpp % 2 == 0);
static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
constexpr index_t num_dpp_c_elems =
selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
if constexpr(selected_dpp.share_a)
{
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
}
else
{
static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
}
// Below checks come from the restrictions of the current implementation, could be removed
// in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
}
static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__host__ __device__ constexpr DppGemm()
{
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
}
__device__ static constexpr index_t GetRegSizePerDpp()
{
return MPerDpp * NPerDpp / dpp_instr.wave_size;
}
template <class ADataType, class BDataType, class CDataType>
__device__ void
Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
{
static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
"base BaseType must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneIdInWave()
{
return get_thread_local_1d_id() % dpp_instr.wave_size;
}
__device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
__device__ static auto GetLaneIdInLaneGroup()
{
return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
}
__device__ static auto GetLaneGroupIdInWave()
{
return GetLaneIdInWave() / dpp_instr.lanegroup_size;
}
__device__ static auto GetDppOpIdx()
{
const auto lanegroupId = GetLaneGroupIdInWave();
constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
make_multi_index(lanegroupId));
const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1];
return make_tuple(m_dpp_idx, n_dpp_idx);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
{
const auto laneId = get_thread_local_1d_id();
const auto wave_row = laneId / dpp_instr.n_per_wave;
auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
return make_tuple(0, m_idx % dpp_instr.m_per_wave);
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
{
const auto laneId = get_thread_local_1d_id();
return make_tuple(0, laneId % dpp_instr.n_per_wave);
}
__device__ static CIndex GetBeginOfThreadBlk()
{
const auto dpp_op_idx = GetDppOpIdx();
const auto m_dpp_op_idx = dpp_op_idx[I0];
const auto n_dpp_op_idx = dpp_op_idx[I1];
index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
return CIndex{m_offset, n_offset};
}
static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
static constexpr auto dpp_instr = dpp.selected_dpp;
static constexpr auto K0PerDpp = 1;
static constexpr auto K1PerDpp = dpp.GetK1PerDpp();
__host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
{
return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
}
};
} // namespace ck
...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector struct MfmaSelector
...@@ -640,6 +642,7 @@ struct MfmaSelector ...@@ -640,6 +642,7 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -651,6 +654,7 @@ struct MfmaSelector ...@@ -651,6 +654,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...@@ -852,7 +856,11 @@ struct XdlopsGemm ...@@ -852,7 +856,11 @@ struct XdlopsGemm
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value, is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
...@@ -164,6 +164,7 @@ template < ...@@ -164,6 +164,7 @@ template <
index_t BK1, index_t BK1,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM, bool DoPadGemmM,
bool DoPadGemmN> bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1 struct TransformConvBwdDataToGemm_v1
...@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
// A: output tensor // A: output tensor
...@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock), make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{}); Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc, out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock), make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{}); Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc, out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor // B weight tensor
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
...@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmn_padded_grid_desc = const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock), make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{}); Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc, wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
...@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmk_gemm_padded_grid_desc = const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock), make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{}); Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc, wei_gemmk_gemmn_padded_grid_desc,
make_tuple( make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))), wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
} }
else else
{ {
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8
} }
#endif #endif
#endif
} }
// buffer_load requires: // buffer_load requires:
...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = auto tmp =
...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
} }
else else
{ {
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8
} }
#endif
} }
#endif #endif
} }
......
...@@ -5,17 +5,63 @@ ...@@ -5,17 +5,63 @@
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/amd_gemm_dpp.hpp" #include "ck/utility/inner_product_dpp8.hpp"
namespace ck { namespace ck {
namespace dpp8 { namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers. template <class ABDataType>
constexpr index_t lane_group_size = 8; struct dpp_datatypes;
__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } template <>
__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } struct dpp_datatypes<half_t>
{
// Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
// single instruction.
using a_dtype = half_t;
using b_dtype = half_t;
using c_dtype = float;
static constexpr index_t k_per_instr = 2;
};
template <index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
class BaseInputType,
class AVecDataType,
class BVecDataType,
class CVecDataType,
bool ShareA>
struct DppLanegroupGemm
{
using datatypes_conf = dpp_datatypes<BaseInputType>;
using ADataType = typename datatypes_conf::a_dtype;
using BDataType = typename datatypes_conf::b_dtype;
using CDataType = typename datatypes_conf::c_dtype;
__device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
{
constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
const vector_type<ADataType, KPerThread> a_vector{a_vec};
const vector_type<BDataType, KPerThread> b_vector{b_vec};
static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
float c = c_vec.template AsType<CDataType>()(c_idx);
// Next `c_idx` implies that we need to pull data from the next lane.
constexpr index_t source_lane = c_idx;
static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
ck::dpp8::
inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
a_k_vec, b_k_vec, c);
});
c_vec.template AsType<CDataType>()(c_idx) = c;
});
}
};
} // namespace dpp8 } // namespace dpp8
......
...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -12,7 +12,12 @@ using half_t = _Float16; ...@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = uint8_t; #if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -143,14 +148,24 @@ struct scalar_type<int4_t> ...@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
//
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t> ...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000 static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); } __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); } __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); } __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
template <typename T>
struct NumericUtils
{
};
template <>
struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
} // namespace ck } // namespace ck
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -22,53 +25,38 @@ namespace ck::utils { ...@@ -22,53 +25,38 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int out_mant = NumericUtils<Y>::mant;
// fp8 exponent/mantissa layout // original type exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int f8_mant = 3; constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
if constexpr(is_float) head = x_bitwise & NumericUtils<X>::head_mask;
{ mantissa = x_bitwise & NumericUtils<X>::mant_mask;
head = x_bitwise & 0xFF800000; exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
mantissa = x_bitwise & 0x7FFFFF; sign = head >> (in_exp + in_mant);
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant); uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
} uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
else if constexpr(is_half) constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0 // check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant; mantissa += 1 << in_mant;
// apply random number if needed // apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant)) if(mantissa >= (2 << in_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (type_mant - f8_mant); mantissa >>= (in_mant - out_mant);
// check negative exponent // check negative exponent
if(exponent <= 0) if(exponent <= 0)
...@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << f8_mant) - 1; mantissa = (1 << out_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0 // check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << f8_mant) - 1; mantissa &= (1 << out_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int in_exp = NumericUtils<X>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int in_mant = NumericUtils<X>::mant;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes // prepare the codes
constexpr uint8_t nan_code = 0x80; constexpr X nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0; Y Inf, NegInf, NaN, Neg0;
if constexpr(is_half) using T_bitwise = typename NumericUtils<Y>::bitwise_type;
{
constexpr uint16_t ihInf = 0x7C00; constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr uint16_t ihNegInf = 0xFC00; constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr uint16_t ihNaN = 0x7C01; constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr uint16_t ihNeg0 = 0x8000; constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
} Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
else if constexpr(is_float)
{ // check if x is 0.0
constexpr uint32_t ifInf = 0x7F800000; if(x == 0)
constexpr uint32_t ifNegInf = 0xFF800000; return static_cast<Y>(0);
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; T_bitwise retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == nan_code) if(x == nan_code)
return fNaN; return NaN;
} }
else else
{ {
if(x == nan_code) if(x == nan_code)
return fNeg0; return Neg0;
if(exponent == ((1 << f8_exp) - 1)) if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
} }
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); exponent++;
mantissa <<= sh; while(mantissa < (1 << in_mant))
mantissa &= ((1 << f8_mant) - 1); {
exponent += 1 - sh; mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << in_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << type_mant; mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
} // namespace } // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{ {
// check datatype // check datatypes
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8."); static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng); return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ Y cast_from_f8(X x)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported."); static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0 return run_cast_from_f8<X, Y, negative_zero_nan>(x);
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck::utils } // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c); c);
} }
template <>
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "amd_gemm_dpp.hpp" #include "amd_gemm_dpp.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "type_convert.hpp" #include "type_convert.hpp"
...@@ -10,6 +11,9 @@ namespace ck { ...@@ -10,6 +11,9 @@ namespace ck {
namespace dpp8 { namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;
template <int SrcLaneIdx> template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment