Commit 30d8a363 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Enable f16/f8 mixed precision

parent 844b215d
...@@ -35,13 +35,17 @@ __global__ void ...@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop> template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem) typename GridwiseGemm::Problem problem)
{ {
...@@ -61,7 +65,8 @@ __global__ void ...@@ -61,7 +65,8 @@ __global__ void
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
...@@ -102,7 +107,8 @@ template <typename ALayout, ...@@ -102,7 +107,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
bool EnableMixedPrecision = false>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -122,6 +128,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -122,6 +128,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// convert f8 to f16 if mixed precision enabled
using FloatAAdjusted =
conditional_t<EnableMixedPrecision && is_same_v<FloatA, ck::f8_t>, ck::half_t, FloatA>;
using FloatBAdjusted =
conditional_t<EnableMixedPrecision && is_same_v<FloatB, ck::f8_t>, ck::half_t, FloatB>;
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
{ {
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
...@@ -463,8 +475,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -463,8 +475,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem struct Argument : public tensor_operation::device::BaseArgument, public Problem
{ {
__host__ Argument(const FloatAB* p_a_grid_, __host__ Argument(const FloatA* p_a_grid_,
const FloatAB* p_b_grid_, const FloatB* p_b_grid_,
FloatC* p_c_grid_, FloatC* p_c_grid_,
index_t M_, index_t M_,
index_t N_, index_t N_,
...@@ -479,8 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -479,8 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
} }
const FloatAB* p_a_grid; const FloatA* p_a_grid;
const FloatAB* p_b_grid; const FloatB* p_b_grid;
FloatC* p_c_grid; FloatC* p_c_grid;
}; };
...@@ -541,8 +553,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -541,8 +553,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned * sizeof(FloatAAdjusted) +
sizeof(FloatAB), b_block_space_size_aligned * sizeof(FloatBAdjusted)),
c_block_size * sizeof(FloatCShuffle)); c_block_size * sizeof(FloatCShuffle));
} }
...@@ -676,8 +688,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -676,8 +688,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Problem& problem) const Problem& problem)
...@@ -743,8 +755,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -743,8 +755,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<AK0Number, MPerBlock, AK1Number>, Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatAB, FloatAAdjusted,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -774,8 +786,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -774,8 +786,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<BK0Number, NPerBlock, BK1Number>, Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatAB, FloatBAdjusted,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -796,6 +808,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -796,6 +808,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// check if A and B types are same
static_assert(std::is_same<FloatAAdjusted, FloatBAdjusted>::value,
"The A and B types do not match!");
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
...@@ -805,11 +821,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -805,11 +821,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<FloatAAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAAdjusted,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -827,10 +843,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -827,10 +843,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<FloatAAdjusted*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment