Unverified Commit e1a5137e authored by arai713, committed by GitHub

Merge branch 'develop' into transpose_5d

parents eb57178d 718065eb
...@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
             c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 CGridDesc_M_N{}))>;

     using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 C0GridDesc_M_N{}))>;

     using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 C1GridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
...@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
                 FloatC, // typename Src1Data,
                 FloatC, // typename Src2Data,
                 FloatC, // typename DstData,
-                decltype(
-                    c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-                decltype(
-                    c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-                decltype(
-                    c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-                decltype(
-                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+                decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+                decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+                decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+                decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
                 Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
                 5,                                          // index_t VectorDim,
                 CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector,
typename Block2ETileMap>
struct GridwiseImageToColumn
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__device__ static void Run(const InputGridDesc& in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc& out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap& block_2_tile_map)
{
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
Tuple<InputDataType>,
Tuple<OutputDataType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<MPerBlock, KPerBlock>,
ThreadClusterLengths,
Sequence<0, 1>,
Sequence<0, 1>,
I1,
ScalarPerVector,
Sequence<true>,
Sequence<true>>{
in_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
out_grid_desc,
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
}
__host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
const OutputGridDesc& out_grid_desc)
{
if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
in_grid_desc.GetLength(I1) % KPerBlock != 0)
return false;
if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
out_grid_desc.GetLength(I1) % KPerBlock != 0)
return false;
return true;
}
};
} // namespace ck
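Editor's note: the kernel above copies one MPerBlock x KPerBlock tile per workgroup, which is why CheckValidity() requires both extents to divide evenly. The host-side sketch below is not part of the commit; all sizes and the row-major tile order are made-up assumptions used only to model the tile decomposition and the per-block origins computed in Run().

// Editor's sketch (not part of this commit): host-side model of the tiling
// contract enforced by GridwiseImageToColumn::CheckValidity above.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 64; // assumed tile sizes
    constexpr int KPerBlock = 32;
    constexpr int M = 256;        // assumed 2D (matrix) view of the image
    constexpr int K = 128;

    // Same divisibility requirement as CheckValidity: every block copies a
    // full MPerBlock x KPerBlock tile, so both extents must tile evenly.
    const bool valid = (M % MPerBlock == 0) && (K % KPerBlock == 0);

    // Each workgroup `b` is mapped by the Block2ETileMap to a 2D tile index;
    // Run() then scales it to element offsets, as in
    // m_block_data_idx_on_grid = block_work_idx[I0] * MPerBlock.
    const int grid_m = M / MPerBlock;
    const int grid_k = K / KPerBlock;
    for(int b = 0; b < grid_m * grid_k; ++b)
    {
        const int m_origin = (b / grid_k) * MPerBlock; // row-major tile order assumed
        const int k_origin = (b % grid_k) * KPerBlock;
        std::printf("block %d copies tile origin (%d, %d)\n", b, m_origin, k_origin);
    }
    return valid ? 0 : 1;
}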
...@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
     using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
         Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));

-    using ThreadwiseWolfordDescReduce = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
+    using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<DimSubBlocks * DimThreadSize>{})));

     using ThreadwiseWelford =
         ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
......
...@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
        int left_kPerBlock  = math::integer_divide_ceil(k, kGridSize);
        int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
        int kPerThread      = kRightmostBlock < K_BlockTileSize
                                  ? 0
                                  : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
        int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;

        if(kPerBlockTail > 0)
        {
......
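Editor's note: a worked example (not part of the commit) of the split-K tail arithmetic in the hunk above, with made-up sizes; in the kernel `k` may already be a padded length, here it is simply taken equal to `kRaw`.

#include <cstdio>

int main()
{
    const int kRaw               = 1000; // assumed reduction length
    const int kGridSize          = 4;    // assumed number of K blocks
    const int KThreadClusterSize = 64;   // assumed block shape
    const int KThreadSliceSize   = 4;
    const int K_BlockTileSize    = KThreadClusterSize * KThreadSliceSize; // 256
    const int k                  = kRaw;

    const int left_kPerBlock  = (k + kGridSize - 1) / kGridSize;         // 250
    const int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); // 250
    const int kPerThread      = kRightmostBlock < K_BlockTileSize
                                    ? 0
                                    : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
    const int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;

    // kRightmostBlock (250) is smaller than one block tile (256), so the whole
    // rightmost chunk goes through the kPerBlockTail > 0 path.
    std::printf("kPerThread=%d kPerBlockTail=%d\n", kPerThread, kPerBlockTail);
    return 0;
}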
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
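Editor's note: a host-side sketch (not part of the commit) of the A-sharing index math used by Run() above when ShareA is true: each lane stores only TM1 / lane_group_size rows of A and pulls the rest from its DPP8 lane group. TM1 = 16 is an illustrative assumption; lane_group_size = 8 matches dpp8.

#include <cstdio>

int main()
{
    constexpr int lane_group_size       = 8;  // DPP8 shares data among 8 lanes
    constexpr int TM1                   = 16; // assumed per-thread M tile (ShareA case)
    constexpr int shared_elems_per_lane = TM1 / lane_group_size; // 2

    for(int tm1 = 0; tm1 < TM1; ++tm1)
    {
        // Each lane physically stores only `shared_elems_per_lane` A values;
        // the remaining rows are fetched from neighbouring lanes via DPP,
        // mirroring local_tm1 and src_lane in Run().
        const int local_tm1 = tm1 % shared_elems_per_lane;
        const int src_lane  = (tm1 / shared_elems_per_lane) % lane_group_size;
        std::printf("tm1=%2d -> local A slot %d pulled from lane %d\n",
                    tm1, local_tm1, src_lane);
    }
    return 0;
}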
...@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-                SrcData v;
+                DstData v;

                 // apply element-wise operation
                 element_op_(v, src_buf[Number<src_offset>{}]);

-                // apply type convert
-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
+                dst_vector.template AsType<DstData>()(i) = v;
             });

             const bool is_dst_valid =
...@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
             constexpr index_t dst_offset = dst_desc.CalculateOffset(
                 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-            SrcData v;
+            DstData v;

             // apply element-wise operation
             element_op_(v, src_buf[Number<src_offset>{}]);

             // apply type convert
-            dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+            dst_buf(Number<dst_offset>{}) = v;
         });
     });
 }
......
...@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

+        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+                      "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
+
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};

         constexpr auto ordered_src_access_lengths =
......
...@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                SrcData v;
+                DstData v;

                 // apply element-wise operation
                 element_op_(v, src_vector_container.template AsType<SrcData>()[i]);

                 // apply type convert
-                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
+                dst_vector_container.template AsType<DstData>()(i) = v;
             });

             const bool is_dst_valid =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
namespace ck {
enum struct DppInstr
{
dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2
};
/**
* Structure representing DPP GEMM executed by a single wavefront.
*
* Each structure instantiation must contain the following fields:
* - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
* number of threads in a wavefront;
* - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
* it's 8 in case of DPP8;
* - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
* operation;
* - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
* - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
* - m_per_thread - size along M dimension of the tile calculated by a single thread;
* - n_per_thread - size along N dimension of the tile calculated by a single thread;
* - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
* - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
*
* Not all combinations are supported yet; for the current restrictions see the static asserts
* in the DppSelector constructor.
*/
template <DppInstr instr>
struct dpp_type;
template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 32;
static constexpr index_t n_per_wave = 8;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 16;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 8;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 8;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 1;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
static constexpr auto GetDpp();
template <>
static constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
{
static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
static_assert(selected_dpp.k_per_dpp % 2 == 0);
static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
constexpr index_t num_dpp_c_elems =
selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
if constexpr(selected_dpp.share_a)
{
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
}
else
{
static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
selected_dpp.lanegroup_size);
static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
}
// Below checks come from the restrictions of the current implementation, could be removed
// in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
}
static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__host__ __device__ constexpr DppGemm()
{
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
}
__device__ static constexpr index_t GetRegSizePerDpp()
{
return MPerDpp * NPerDpp / dpp_instr.wave_size;
}
template <class ADataType, class BDataType, class CDataType>
__device__ void
Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
{
static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
"base BaseType must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneIdInWave()
{
return get_thread_local_1d_id() % dpp_instr.wave_size;
}
__device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
__device__ static auto GetLaneIdInLaneGroup()
{
return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
}
__device__ static auto GetLaneGroupIdInWave()
{
return GetLaneIdInWave() / dpp_instr.lanegroup_size;
}
__device__ static auto GetDppOpIdx()
{
const auto lanegroupId = GetLaneGroupIdInWave();
constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
make_multi_index(lanegroupId));
const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1];
return make_tuple(m_dpp_idx, n_dpp_idx);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
{
const auto laneId = get_thread_local_1d_id();
const auto wave_row = laneId / dpp_instr.n_per_wave;
auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
return make_tuple(0, m_idx % dpp_instr.m_per_wave);
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
{
const auto laneId = get_thread_local_1d_id();
return make_tuple(0, laneId % dpp_instr.n_per_wave);
}
__device__ static CIndex GetBeginOfThreadBlk()
{
const auto dpp_op_idx = GetDppOpIdx();
const auto m_dpp_op_idx = dpp_op_idx[I0];
const auto n_dpp_op_idx = dpp_op_idx[I1];
index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
return CIndex{m_offset, n_offset};
}
static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
static constexpr auto dpp_instr = dpp.selected_dpp;
static constexpr auto K0PerDpp = 1;
static constexpr auto K1PerDpp = dpp.GetK1PerDpp();
__host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
{
return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
}
};
} // namespace ck
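Editor's note: the sketch below (not part of the commit) re-evaluates the consistency checks from the DppSelector constructor for the dpp8_f16_32x8x2 configuration defined above; every number is copied from that specialization, nothing else is assumed.

#include <cassert>
#include <cstdio>

int main()
{
    const int wave_size = 32, lanegroup_size = 8;
    const int m_per_wave = 32, n_per_wave = 8;
    const int m_per_lanegroup = 8, n_per_lanegroup = 8;
    const int m_per_thread = 8, n_per_thread = 1;

    const int num_dpp_per_wave = wave_size / lanegroup_size;        // 4 lane groups per wave
    const int num_wave_c_elems = m_per_wave * n_per_wave;           // 256 C elements per wave
    const int num_dpp_c_elems  = m_per_lanegroup * n_per_lanegroup; // 64 per lane group
    assert(num_wave_c_elems % num_dpp_c_elems == 0);
    assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);

    // share_a == true: each lane keeps its full m_per_thread strip of A, while
    // the N extent of a lane group is spread across its 8 lanes.
    assert(m_per_lanegroup == m_per_thread);
    assert(n_per_lanegroup == n_per_thread * lanegroup_size);

    // DppGemm::GetRegSizePerDpp() = MPerDpp * NPerDpp / wave_size.
    const int reg_c_per_thread = m_per_wave * n_per_wave / wave_size; // 8 accumulators
    std::printf("C accumulators per thread: %d (= m_per_thread * n_per_thread = %d)\n",
                reg_c_per_thread, m_per_thread * n_per_thread);
    return 0;
}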
...@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
     }
 };

+#if defined CK_ENABLE_FP8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
 {
...@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
         intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
+#endif

 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
...@@ -640,6 +642,7 @@ struct MfmaSelector
     }
 #endif

+#if defined CK_ENABLE_FP8
     template <>
     static constexpr auto GetMfma<f8_t, 32, 32>()
     {
...@@ -651,6 +654,7 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32f8f8;
     }
+#endif

     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...@@ -852,7 +856,11 @@ struct XdlopsGemm
     {
         static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                           is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
+                          is_same<base_type, int8_t>::value
+#if defined CK_ENABLE_FP8
+                          || is_same<base_type, f8_t>::value
+#endif
+                      ,
                       "base base_type must be double, float, half, bfloat16, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
...@@ -164,6 +164,7 @@ template <
           index_t BK1,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
+          index_t GemmKPerBlock,
           bool DoPadGemmM,
           bool DoPadGemmN>
 struct TransformConvBwdDataToGemm_v1
...@@ -236,8 +237,6 @@ struct TransformConvBwdDataToGemm_v1
         const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
         const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

-        const index_t AK0 = K / AK1;
-
         // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
         const auto out_grid_desc =
             make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
...@@ -247,6 +246,8 @@ struct TransformConvBwdDataToGemm_v1
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Filter1x1Stride1Pad0)
         {
+            const index_t AK0 = math::integer_divide_ceil(K, AK1);
+
             // A: output tensor
             const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
                 out_grid_desc,
...@@ -332,7 +333,7 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

-            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
+            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
                 transform_tensor_descriptor(
                     out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
                     make_tuple(make_pass_through_transform(N),
...@@ -340,7 +341,7 @@ struct TransformConvBwdDataToGemm_v1
                                make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                                make_slice_transform(XDot, I0, XDotSlice),
                                make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
-                               make_unmerge_transform(make_tuple(AK0, AK1))),
+                               make_pass_through_transform(K)),
                     make_tuple(Sequence<0>{},
                                Sequence<1>{},
                                Sequence<2>{},
...@@ -352,21 +353,30 @@ struct TransformConvBwdDataToGemm_v1
                                Sequence<2>{},
                                Sequence<3>{},
                                Sequence<4>{},
-                               Sequence<5, 6>{}));
+                               Sequence<5>{}));

-            const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
-                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
-                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
-                           make_pass_through_transform(AK1)),
-                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
+                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

-            const auto out_gemmak0_gemmm_gemmak1_grid_desc =
+            const auto out_gemmk_gemmm_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
-                    out_gemmak0_gemmmraw_gemmak1_grid_desc,
-                    make_tuple(AK0, GemmMPerBlock, AK1),
-                    Sequence<false, DoPadGemmM, false>{});
+                    out_gemmk_gemmmraw_grid_desc,
+                    make_tuple(GemmKPerBlock, GemmMPerBlock),
+                    Sequence<true, DoPadGemmM>{});
+
+            const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
+
+            const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                out_gemmk_gemmm_padded_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_pass_through_transform(
+                               out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

             return out_gemmak0_gemmm_gemmak1_grid_desc;
         }
...@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
                         Sequence<7>{}));

             const auto
-                out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
+                out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
                     transform_tensor_descriptor(
                         out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
                         make_tuple(make_pass_through_transform(N),
...@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
                                    make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                                    make_slice_transform(XDot, I0, XDotSlice),
                                    make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
-                                   make_unmerge_transform(make_tuple(AK0, AK1))),
+                                   make_pass_through_transform(K)),
                         make_tuple(Sequence<0>{},
                                    Sequence<1>{},
                                    Sequence<2>{},
...@@ -437,22 +447,31 @@ struct TransformConvBwdDataToGemm_v1
                                    Sequence<4>{},
                                    Sequence<5>{},
                                    Sequence<6>{},
-                                   Sequence<7, 8>{}));
+                                   Sequence<7>{}));

-            const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
+            const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
                 make_tuple(
-                    make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)),
-                    make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
-                    make_pass_through_transform(AK1)),
-                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                    make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                    make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
+                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

-            const auto out_gemmak0_gemmm_gemmak1_grid_desc =
+            const auto out_gemmk_gemmm_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
-                    out_gemmak0_gemmmraw_gemmak1_grid_desc,
-                    make_tuple(AK0, GemmMPerBlock, AK1),
-                    Sequence<false, DoPadGemmM, false>{});
+                    out_gemmk_gemmmraw_grid_desc,
+                    make_tuple(GemmKPerBlock, GemmMPerBlock),
+                    Sequence<true, DoPadGemmM>{});
+
+            const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
+
+            const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                out_gemmk_gemmm_padded_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_pass_through_transform(
+                               out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

             return out_gemmak0_gemmm_gemmak1_grid_desc;
         }
...@@ -505,8 +524,6 @@ struct TransformConvBwdDataToGemm_v1
         const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
         const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

-        const index_t BK0 = K / BK1;
-
         // assume packed
         // k_y_x_c for 2d or k_z_y_x_c for 3d
         const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
...@@ -515,6 +532,8 @@ struct TransformConvBwdDataToGemm_v1
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Filter1x1Stride1Pad0)
         {
+            const index_t BK0 = math::integer_divide_ceil(K, BK1);
+
             // B: weight tensor
             const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
                 transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
...@@ -566,43 +585,49 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

-            const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
-                transform_tensor_descriptor(
-                    wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
-                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                               make_slice_transform(YDot, I0, YDotSlice),
-                               make_slice_transform(XDot, I0, XDotSlice),
-                               make_freeze_transform(i_ytilde),
-                               make_freeze_transform(i_xtilde),
-                               make_pass_through_transform(C)),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<3>{},
-                               Sequence<2>{},
-                               Sequence<4>{},
-                               Sequence<5>{}),
-                    make_tuple(Sequence<0, 1>{},
-                               Sequence<2>{},
-                               Sequence<3>{},
-                               Sequence<>{},
-                               Sequence<>{},
-                               Sequence<4>{}));
+            const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
+                wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
+                make_tuple(make_pass_through_transform(K),
+                           make_slice_transform(YDot, I0, YDotSlice),
+                           make_slice_transform(XDot, I0, XDotSlice),
+                           make_freeze_transform(i_ytilde),
+                           make_freeze_transform(i_xtilde),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<3>{},
+                           Sequence<2>{},
+                           Sequence<4>{},
+                           Sequence<5>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<>{},
+                           Sequence<>{},
+                           Sequence<3>{}));

-            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
-                wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
-                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
-                           make_pass_through_transform(C),
-                           make_pass_through_transform(BK1)),
-                make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                wei_k_ydotslice_xdotslice_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

-            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
+            const auto wei_gemmk_gemmn_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
-                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
-                    make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
-                               GemmNPerBlock,
-                               BK1),
-                    Sequence<false, DoPadGemmN, false>{});
+                    wei_gemmk_gemmnraw_grid_desc,
+                    make_tuple(GemmKPerBlock, GemmNPerBlock),
+                    Sequence<true, DoPadGemmN>{});
+
+            const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
+
+            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
+                wei_gemmk_gemmn_padded_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_pass_through_transform(
+                               wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

             return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
         }
...@@ -631,10 +656,10 @@ struct TransformConvBwdDataToGemm_v1
                            Sequence<5, 6>{},
                            Sequence<7>{}));

-            const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc =
+            const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
                 transform_tensor_descriptor(
                     wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
-                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                    make_tuple(make_pass_through_transform(K),
                                make_slice_transform(ZDot, I0, ZDotSlice),
                                make_slice_transform(YDot, I0, YDotSlice),
                                make_slice_transform(XDot, I0, XDotSlice),
...@@ -650,33 +675,39 @@ struct TransformConvBwdDataToGemm_v1
                                Sequence<4>{},
                                Sequence<6>{},
                                Sequence<7>{}),
-                    make_tuple(Sequence<0, 1>{},
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
                                Sequence<2>{},
                                Sequence<3>{},
-                               Sequence<4>{},
                                Sequence<>{},
                                Sequence<>{},
                                Sequence<>{},
-                               Sequence<5>{}));
+                               Sequence<4>{}));

-            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
-                wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc,
-                make_tuple(
-                    make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)),
-                    make_pass_through_transform(C),
-                    make_pass_through_transform(BK1)),
-                make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

-            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
+            const auto wei_gemmk_gemmn_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
-                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
-                    make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
-                               GemmNPerBlock,
-                               BK1),
-                    Sequence<false, DoPadGemmN, false>{});
+                    wei_gemmk_gemmnraw_grid_desc,
+                    make_tuple(GemmKPerBlock, GemmNPerBlock),
+                    Sequence<true, DoPadGemmN>{});

-            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
+            const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
+
+            const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
+                wei_gemmk_gemmn_padded_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_pass_through_transform(
+                               wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return wei_gemmbk0_gemm_gemmbk1_grid_desc;
         }
         else
         {
......
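Editor's note: the net effect of this diff is that the GEMM-K axis is padded up to a multiple of GemmKPerBlock before being split into (AK0, AK1), instead of deriving AK0 from the raw K. The sketch below is not part of the commit; all sizes are made-up and it only walks through that arithmetic for the 2D output-tensor path.

#include <cstdio>

int main()
{
    const int YDotSlice = 3, XDotSlice = 3, K = 20; // assumed slice sizes
    const int GemmKPerBlock = 32;                   // assumed K block tile
    const int AK1           = 8;                    // assumed K1 vector length

    const int gemm_k_raw    = YDotSlice * XDotSlice * K;                        // 180
    const int gemm_k_padded = ((gemm_k_raw + GemmKPerBlock - 1) / GemmKPerBlock)
                              * GemmKPerBlock;                                  // 192
    const int AK0           = gemm_k_padded / AK1;                              // 24

    // AK0 is now derived from the padded descriptor length (GetLength(I0) / AK1),
    // so AK0 * AK1 always covers a whole number of GemmKPerBlock tiles.
    std::printf("gemm_k_raw=%d padded=%d AK0=%d AK1=%d\n",
                gemm_k_raw, gemm_k_padded, AK0, AK1);
    return 0;
}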
...@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
    }
    else
    {
+#endif
        return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
            src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
    }
+#endif
#else
+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
    }
    else
    {
+#endif
        vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
            src_wave_buffer_resource, src_thread_addr_offset, 0);
        return src_thread_element_valid ? tmp : vector_t(0);
+#if defined CK_ENABLE_FP8
    }
#endif
+#endif
}

// buffer_load requires:
...@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp =
...@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
    }
    else
    {
+#endif
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
    }
+#endif
#else
    if(dst_thread_element_valid)
    {
+#if defined CK_ENABLE_FP8
        if constexpr(is_same<scalar_t, f8_t>::value)
        {
            auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
        }
        else
        {
+#endif
            amd_buffer_store_impl<scalar_t, vector_size, coherence>(
                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
        }
+#endif
    }
#endif
}
......
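Editor's note: a host-side model (not part of the commit) of the OOB-check offset trick used in the branches above. As far as the editor understands it, the trick relies on AMD buffer instructions returning 0 on loads (and dropping stores) when the computed offset falls outside the range described by the buffer resource, so invalid elements are redirected out of range instead of being branched over; the sizes below are assumptions.

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t buffer_num_bytes   = 4096;  // assumed buffer-resource size
    const uint32_t thread_addr_offset = 256;   // assumed per-thread byte offset
    const bool     element_valid      = false; // pretend the predicate failed

    // Same shift as src_addr_shift / dst_addr_shift in the code above.
    const uint32_t addr_shift = element_valid ? 0u : 0x80000000u;
    const uint32_t effective  = addr_shift + thread_addr_offset;

    // 0x80000100 is far beyond the 4 KiB range described by the resource, so a
    // real buffer_load would return 0 here without any divergent branch.
    std::printf("effective offset 0x%08X, in range: %s\n",
                effective, (effective < buffer_num_bytes) ? "yes" : "no");
    return 0;
}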
...@@ -5,17 +5,63 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
-#include "ck/utility/amd_gemm_dpp.hpp"
+#include "ck/utility/inner_product_dpp8.hpp"

namespace ck {
namespace dpp8 {

-/// Number of lanes that can share data using DPP8 modifiers.
-constexpr index_t lane_group_size = 8;
-__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; }
-__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; }
+template <class ABDataType>
+struct dpp_datatypes;
+
+template <>
+struct dpp_datatypes<half_t>
{
// Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
// single instruction.
using a_dtype = half_t;
using b_dtype = half_t;
using c_dtype = float;
static constexpr index_t k_per_instr = 2;
};
template <index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
class BaseInputType,
class AVecDataType,
class BVecDataType,
class CVecDataType,
bool ShareA>
struct DppLanegroupGemm
{
using datatypes_conf = dpp_datatypes<BaseInputType>;
using ADataType = typename datatypes_conf::a_dtype;
using BDataType = typename datatypes_conf::b_dtype;
using CDataType = typename datatypes_conf::c_dtype;
__device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
{
constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
const vector_type<ADataType, KPerThread> a_vector{a_vec};
const vector_type<BDataType, KPerThread> b_vector{b_vec};
static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
float c = c_vec.template AsType<CDataType>()(c_idx);
// Next `c_idx` implies that we need to pull data from the next lane.
constexpr index_t source_lane = c_idx;
static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
ck::dpp8::
inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
a_k_vec, b_k_vec, c);
});
c_vec.template AsType<CDataType>()(c_idx) = c;
});
}
};
} // namespace dpp8
......
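Editor's note: a host-side sketch (not part of the commit) of the loop structure in dpp8::DppLanegroupGemm::Run above for the half_t configuration: one C accumulator per source lane when ShareA is true, and one dot instruction per k_per_instr = 2 elements of K. MPerThread = 8 and KPerThread = 2 are assumptions chosen to match the dpp8_f16_8x32x2 style of tile.

#include <cstdio>

int main()
{
    const int MPerThread  = 8; // ShareA: one C element per source lane
    const int KPerThread  = 2;
    const int k_per_instr = 2; // one v_dot2-style instruction reduces 2 K values

    for(int c_idx = 0; c_idx < MPerThread; ++c_idx)
    {
        const int source_lane = c_idx;                    // lane owning the A row for this C element
        const int num_chunks  = KPerThread / k_per_instr; // 1 dot instruction per chunk here
        std::printf("c[%d]: %d dot instruction(s), A fetched from lane %d\n",
                    c_idx, num_chunks, source_lane);
    }
    return 0;
}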
...@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
    }
};

+#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
...@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
    }
};
+#endif

} // namespace ck
#endif
...@@ -12,7 +12,12 @@ using half_t = _Float16; ...@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
using f8_t = uint8_t; #if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -143,14 +148,24 @@ struct scalar_type<int4_t> ...@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
//
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t> ...@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000 static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111 static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); } __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); } __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); } __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
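As a sanity check on the limits above: in the negative-zero-NaN (fnuz) encodings a normal value decodes as (-1)^sign * 2^(exponent - bias) * (1 + mantissa / 2^mantissa_bits), with a 4/3 exponent/mantissa split and bias 8 for `f8_t`, and a 5/2 split with bias 16 for `bf8_t`. A small host-only decoder for normal values, written as a sketch rather than CK API:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal fnuz-encoded byte; exp_bits/mant_bits are 4/3 (f8) or 5/2 (bf8).
// Subnormals and the NaN code 0x80 are intentionally not handled here.
double decode_fnuz(uint8_t bits, int exp_bits, int mant_bits, int bias)
{
    const int sign     = bits >> (exp_bits + mant_bits);
    const int exponent = (bits >> mant_bits) & ((1 << exp_bits) - 1);
    const int mantissa = bits & ((1 << mant_bits) - 1);
    const double frac  = 1.0 + static_cast<double>(mantissa) / (1 << mant_bits);
    return (sign ? -1.0 : 1.0) * std::ldexp(frac, exponent - bias);
}

int main()
{
    std::printf("f8  max = %g\n", decode_fnuz(0x7F, 4, 3, 8));  // 240
    std::printf("f8  min = %g\n", decode_fnuz(0x08, 4, 3, 8));  // 0.0078125
    std::printf("bf8 max = %g\n", decode_fnuz(0x7F, 5, 2, 16)); // 57344
    std::printf("bf8 min = %g\n", decode_fnuz(0x04, 5, 2, 16)); // ~3.05e-05
    return 0;
}
```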
template <typename T>
struct NumericUtils
{
};
template <>
struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
} // namespace ck } // namespace ck
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions are available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -22,53 +25,38 @@ namespace ck::utils { ...@@ -22,53 +25,38 @@ namespace ck::utils {
namespace { namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int out_mant = NumericUtils<Y>::mant;
// fp8 exponent/mantissa layout // original type exponent/mantissa layout
constexpr int f8_exp = 4; constexpr int in_exp = NumericUtils<X>::exp;
constexpr int f8_mant = 3; constexpr int in_mant = NumericUtils<X>::mant;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent; int exponent;
uint32_t head, mantissa, sign; uint32_t head, mantissa, sign;
// nan code is same for float and half // nan code is same for float and half
constexpr uint8_t nan_code = 0x80; constexpr Y nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
// convert to bitwise // convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type using T_bitwise = typename NumericUtils<X>::bitwise_type;
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype // unpack the input, depends on datatype
if constexpr(is_float) head = x_bitwise & NumericUtils<X>::head_mask;
{ mantissa = x_bitwise & NumericUtils<X>::mant_mask;
head = x_bitwise & 0xFF800000; exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
mantissa = x_bitwise & 0x7FFFFF; sign = head >> (in_exp + in_mant);
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant); uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
} uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
else if constexpr(is_half) constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
...@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0); return signed_inf + (mantissa != 0 ? 1 : 0);
} }
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0 // check if x is 0.0
if(x_bitwise == 0) if(x_bitwise == 0)
return 0; return 0;
exponent -= exp_low_cutoff - 1; exponent -= exp_low_cutoff - 1;
if(exponent <= 0) if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant; mantissa += 1 << in_mant;
// apply random number if needed // apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask; mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant)) if(mantissa >= (2 << in_mant))
{ {
mantissa >>= 1; mantissa >>= 1;
exponent++; exponent++;
} }
mantissa >>= (type_mant - f8_mant); mantissa >>= (in_mant - out_mant);
// check negative exponent // check negative exponent
if(exponent <= 0) if(exponent <= 0)
...@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{ {
if(clip) if(clip)
{ {
mantissa = (1 << f8_mant) - 1; mantissa = (1 << out_mant) - 1;
exponent = max_exp; exponent = max_exp;
} }
else else
...@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) ...@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0 // check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0) if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << f8_mant) - 1; mantissa &= (1 << out_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x) __host__ __device__ Y run_cast_from_f8(X x)
{ {
// check data type // fp8/bf8 exponent/mantissa layout
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr int in_exp = NumericUtils<X>::exp;
constexpr bool is_float = std::is_same<T, float>::value; constexpr int in_mant = NumericUtils<X>::mant;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout // resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8; constexpr int out_exp = NumericUtils<Y>::exp;
constexpr int type_mant = is_half ? 10 : 23; constexpr int out_mant = NumericUtils<Y>::mant;
// prepare the codes // prepare the codes
constexpr uint8_t nan_code = 0x80; constexpr X nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0; Y Inf, NegInf, NaN, Neg0;
if constexpr(is_half) using T_bitwise = typename NumericUtils<Y>::bitwise_type;
{
constexpr uint16_t ihInf = 0x7C00; constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf;
constexpr uint16_t ihNegInf = 0xFC00; constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
constexpr uint16_t ihNaN = 0x7C01; constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN;
constexpr uint16_t ihNeg0 = 0x8000; constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf)); Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN)); NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0)); NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
} Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
else if constexpr(is_float)
{ // check if x is 0.0
constexpr uint32_t ifInf = 0x7F800000; if(x == 0)
constexpr uint32_t ifNegInf = 0xFF800000; return static_cast<Y>(0);
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input // unpack the input
uint32_t sign = x >> (f8_exp + f8_mant); uint32_t sign = x >> (in_exp + in_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1); uint32_t mantissa = x & ((1 << in_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant; int exponent = (x & 0x7F) >> in_mant;
constexpr int exp_low_cutoff = constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval; T_bitwise retval;
if constexpr(negative_zero_nan) if constexpr(negative_zero_nan)
{ {
if(x == nan_code) if(x == nan_code)
return fNaN; return NaN;
} }
else else
{ {
if(x == nan_code) if(x == nan_code)
return fNeg0; return Neg0;
if(exponent == ((1 << f8_exp) - 1)) if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
{
retval = x;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
} }
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); exponent++;
mantissa <<= sh; while(mantissa < (1 << in_mant))
mantissa &= ((1 << f8_mant) - 1); {
exponent += 1 - sh; mantissa <<= 1;
exponent--;
}
mantissa &= ((1 << in_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant; mantissa <<= out_mant - in_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << type_mant; mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval)); return *(reinterpret_cast<const Y*>(&retval));
} }
} // namespace } // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch> template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{ {
// check datatype // check datatypes
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8."); static_assert(is_half || is_float, "Only half and float can be cast.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng); return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
} }
template <typename T, bool negative_zero_nan> template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x) __host__ __device__ Y cast_from_f8(X x)
{ {
// check datatype // check datatype
constexpr bool is_half = std::is_same<T, half_t>::value; constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported."); static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0 return run_cast_from_f8<X, Y, negative_zero_nan>(x);
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
} }
} // namespace ck::utils } // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
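Since the conversion helpers now take both source and destination types, callers spell out `<X, Y, ...>` explicitly. A hedged round-trip sketch, assuming `CK_ENABLE_FP8` is defined, the fnuz mode is used, and the header path below is correct for this revision:

```cpp
#include "ck/utility/f8_utils.hpp" // assumed header path for the file shown above

// Round-trip a float through f8 with saturation (clip) and no stochastic rounding.
float f8_round_trip(float x)
{
    const ck::f8_t y = ck::utils::cast_to_f8<float,
                                             ck::f8_t,
                                             /*negative_zero_nan=*/true,
                                             /*clip=*/true,
                                             /*stoch=*/false>(x, /*rng=*/0u);
    return ck::utils::cast_from_f8<ck::f8_t, float, /*negative_zero_nan=*/true>(y);
}
```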
...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c); c);
} }
template <>
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
inner_product(type_convert<float>(a), type_convert<float>(b), c);
}
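The new scalar overloads simply promote both operands to float and accumulate there, so no precision is lost in the addition. A minimal usage sketch (device code; the include path is an assumption):

```cpp
#include "ck/utility/inner_product.hpp" // assumed header path

// Accumulate one scalar product into a float accumulator via the new specialization.
__device__ void accumulate_scalar(ck::half_t a, ck::half_t b, float& acc)
{
    ck::inner_product<ck::half_t, ck::half_t, float>(a, b, acc);
}
```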
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "amd_gemm_dpp.hpp" #include "amd_gemm_dpp.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "type_convert.hpp" #include "type_convert.hpp"
...@@ -10,6 +11,9 @@ namespace ck { ...@@ -10,6 +11,9 @@ namespace ck {
namespace dpp8 { namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;
template <int SrcLaneIdx> template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
} // namespace ck
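Because the selector is `constexpr`, the scheduling policy is fixed at compile time from the build flag; a minimal sketch (the include path and the static_assert are illustrative assumptions):

```cpp
#include "ck/utility/loop_scheduler.hpp" // assumed header path for the file above

constexpr ck::LoopScheduler scheduler = ck::make_default_loop_scheduler();
static_assert(scheduler == ck::LoopScheduler::Default ||
                  scheduler == ck::LoopScheduler::Interwave,
              "resolved from CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING");
```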
...@@ -116,7 +116,15 @@ struct Max ...@@ -116,7 +116,15 @@ struct Max
template <typename T> template <typename T>
__host__ __device__ static constexpr T GetIdentityValue() __host__ __device__ static constexpr T GetIdentityValue()
{ {
return NumericLimits<T>::Lowest(); if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
}
}; };
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
...@@ -138,6 +146,15 @@ struct Max ...@@ -138,6 +146,15 @@ struct Max
a = b; a = b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T> template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
...@@ -152,6 +169,18 @@ struct Max ...@@ -152,6 +169,18 @@ struct Max
changed = true; changed = true;
} }
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
}; };
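The float round-trip in the `bhalf_t` overloads matters because `bhalf_t` is stored as a raw 16-bit integer in CK, so comparing bit patterns directly would order negative values incorrectly. A host-side illustration of the pitfall, using bit patterns only (not CK API):

```cpp
#include <cstdint>
#include <cstring>

// Truncate a float to its bf16 bit pattern (round-toward-zero is enough here).
uint16_t to_bf16_bits(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return static_cast<uint16_t>(u >> 16);
}

// to_bf16_bits(-1.0f) == 0xBF80 and to_bf16_bits(2.0f) == 0x4000, so an unsigned
// bit-pattern comparison would claim -1.0 > 2.0; converting to float first, as the
// Max/Min overloads above do, restores the correct mathematical ordering.
```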
struct Min struct Min
...@@ -159,6 +188,15 @@ struct Min ...@@ -159,6 +188,15 @@ struct Min
template <typename T> template <typename T>
__host__ __device__ static constexpr T GetIdentityValue() __host__ __device__ static constexpr T GetIdentityValue()
{ {
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Max();
}
return NumericLimits<T>::Max(); return NumericLimits<T>::Max();
}; };
...@@ -181,6 +219,15 @@ struct Min ...@@ -181,6 +219,15 @@ struct Min
a = b; a = b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T> template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
...@@ -195,6 +242,18 @@ struct Min ...@@ -195,6 +242,18 @@ struct Min
changed = true; changed = true;
} }
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
}; };
struct AMax struct AMax
......