".github/vscode:/vscode.git/clone" did not exist on "b50c818623049e533ca5f8d07dfe5ead3ca08d9a"
Unverified Commit b39f07f1 authored by ltqin, committed by GitHub

Implement MI200 FP16 Denorm fix inside threadwise copy (#191)



* start convert

* using buffer load

* add kernel transfer fun

* using asm for transfer

* add transpose_half_to_bhalf_2x2

* add TypeMap struct

* add LDSDataType to v2r3 and v2r4r2

* change convert fun name

* remove asm in half transfer to bhalf

* fix bug for type_convert

* cshuffle_v1 add LDSDataType

* add LDS type for gridwise gemm v2r4

* add LDS data type to v3r1/2/3

* init complete

* fix function name

* remove comments

* format

* fix for merge develop
Co-authored-by: ltqin <letaoqin@amd.com>
parent 3956085d
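In outline (inferred from the diff that follows): on gfx90a (MI200) the fp16 A/B tiles are converted to bf16 while the threadwise copy writes them into LDS, and the blockwise XDL GEMM then consumes bf16 operands. Because bf16 carries the fp32 exponent range, values that are denormal in fp16 stay normal in bf16, which is the denorm workaround named in the title. The remap is driven by a small type trait added to data_type.hpp; a condensed, self-contained sketch (half_t/bhalf_t stand in for the CK typedefs):
using half_t  = _Float16;       // assumption: CK's 16-bit float typedef
using bhalf_t = unsigned short; // assumption: bf16 stored as a raw 16-bit pattern
template <typename T>
struct TypeMap { using type = T; };                   // identity on most targets
#if defined(__gfx90a__)
template <>
struct TypeMap<half_t> { using type = bhalf_t; };     // MI200: fp16 tiles live in LDS as bf16
#endif
// Each gridwise GEMM then types its LDS buffers with the mapped type:
// using LDSDataType = typename TypeMap<FloatAB>::type;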
......@@ -45,12 +45,21 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
#if 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -110,6 +110,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -180,7 +182,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
sizeof(LDSDataType),
c_block_size * sizeof(FloatCShuffle));
}
......@@ -366,7 +368,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -397,7 +399,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -425,12 +427,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -447,10 +450,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<LDSDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
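Both half_t and bhalf_t are 16-bit, so this remap leaves GetSharedMemoryNumberOfByte() numerically unchanged; a sanity check one could add (not part of this PR):
static_assert(sizeof(ck::half_t) == sizeof(ck::bhalf_t),
              "the fp16 -> bf16 LDS remap must not change the LDS byte size");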
......@@ -190,6 +190,8 @@ template <index_t BlockSize,
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -261,7 +263,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -380,7 +382,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -486,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -517,7 +519,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -548,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -565,10 +567,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
......@@ -38,10 +38,11 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
using LDSDataType = typename TypeMap<FloatAB>::type;
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(LDSDataType);
__shared__ FloatAB p_shared_block[shared_block_size];
__shared__ LDSDataType p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -108,6 +109,7 @@ template <index_t BlockSize,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -161,7 +163,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
return (a_block_space_size + b_block_space_size) * sizeof(LDSDataType);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -263,7 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -320,7 +322,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
LDSDataType* __restrict__ p_shared_block,
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
......@@ -428,7 +430,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -458,7 +460,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -488,7 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -504,8 +506,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
LDSDataType* p_a_block = p_shared_block;
LDSDataType* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
......
......@@ -40,10 +40,11 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
using LDSDataType = typename TypeMap<FloatAB>::type;
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(LDSDataType);
__shared__ FloatAB p_shared_block[shared_block_size];
__shared__ LDSDataType p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -111,6 +112,7 @@ template <index_t BlockSize,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -167,7 +169,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
return math::max((a_block_space_size + b_block_space_size) * sizeof(LDSDataType),
c_block_size * sizeof(FloatC));
}
......@@ -308,7 +310,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
LDSDataType* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -417,7 +419,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -447,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -477,7 +479,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -493,8 +495,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
LDSDataType* p_a_block = p_shared_block;
LDSDataType* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
......@@ -574,7 +576,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
static_cast<FloatC*>(static_cast<void*>(p_shared_block)),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
......
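The two-step cast above goes through void* because p_shared_block is now an LDSDataType* while the C shuffle stage reuses the same LDS storage as FloatC, which is in general an unrelated type; a direct static_cast between unrelated object-pointer types would not compile. A minimal stand-alone illustration of the rule (float/unsigned short stand in for FloatC/LDSDataType):
unsigned short* p_lds = nullptr;                              // LDS element pointer
// float* p_c = static_cast<float*>(p_lds);                   // ill-formed: unrelated pointer types
float* p_c = static_cast<float*>(static_cast<void*>(p_lds));  // OK: round-trip through void*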
......@@ -116,6 +116,7 @@ template <
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -216,7 +217,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
sizeof(LDSDataType),
c_block_size * sizeof(FloatCShuffle));
}
......@@ -421,7 +422,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -452,7 +453,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -480,12 +481,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t k_pack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t k_pack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -502,10 +504,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<LDSDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
......@@ -122,6 +122,7 @@ template <
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -221,7 +222,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
sizeof(LDSDataType),
c_block_size * sizeof(FloatC));
}
......@@ -442,7 +443,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -473,7 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -504,7 +505,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -521,10 +522,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
......@@ -131,6 +131,7 @@ template <
index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
using LDSDataType = typename TypeMap<FloatAB>::type;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -230,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
sizeof(LDSDataType),
c_block_size * sizeof(FloatC));
}
......@@ -463,7 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
LDSDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -523,7 +524,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
LDSDataType,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -540,10 +541,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
......@@ -278,7 +278,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
(is_same<half_t, remove_cvref_t<DstData>>::value ||
is_same<bhalf_t, remove_cvref_t<DstData>>::value) &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
......@@ -340,8 +341,36 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// do data transpose
// TODO type_convert is not used yet!!!!!
if constexpr(is_same<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>::value)
{
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
}
else
{
transpose_convert_vectors<SrcData,
DstData,
DstScalarPerVector,
SrcScalarPerVector>{}(src_vector_refs,
dst_vector_refs);
}
});
}
else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
DstScalarPerVector % 2 == 0 &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value)
{
auto NewSliceLengths = SliceLengths{}.template Modify(
Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
auto VectorStep = SliceLengths{} / NewSliceLengths;
static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
// convert from SrcData to DstData here
auto nidx = idx * VectorStep;
auto vhalf =
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx,
type_convert<bhalf2_t>(vhalf));
});
}
else
......
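For illustration (hypothetical sizes, not from the diff): if SliceLengths were Sequence<1, 4, 8> with SrcVectorDim = 2, the branch above halves the vector dimension and steps by two along it:
// NewSliceLengths = Sequence<1, 4, 4>   (vector dim halved)
// VectorStep      = Sequence<1, 1, 2>   (SliceLengths / NewSliceLengths, element-wise)
// static_ford then visits 1*4*4 = 16 positions; at each one a half2_t pair is read
// from the src scratch, converted with type_convert<bhalf2_t>, and stored as a bhalf2_t pair.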
......@@ -284,7 +284,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
(is_same<half_t, remove_cvref_t<DstData>>::value ||
is_same<bhalf_t, remove_cvref_t<DstData>>::value) &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
{
// each transpose does
......@@ -343,8 +344,27 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// do data transpose
// TODO type_convert is not used yet!!!!!
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
transpose_convert_vectors<SrcData,
DstData,
DstScalarPerVector,
SrcScalarPerVector>{}(src_vector_refs, dst_vector_refs);
});
}
else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
DstScalarPerVector % 2 == 0 &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value)
{
auto NewSliceLengths = SliceLengths{}.template Modify(
Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
auto VectorStep = SliceLengths{} / NewSliceLengths;
static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
// convert from SrcData to DstData here
auto nidx = idx * VectorStep;
auto vhalf =
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx,
type_convert<bhalf2_t>(vhalf));
});
}
else
......
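Note that v3r3 now routes every transpose through transpose_convert_vectors; when SrcData and DstData are both half_t, the specialization defined later in transpose_vectors.hpp degenerates to the plain fp16 2x2 transpose. Illustrative instantiations (the vector counts are hypothetical):
using PlainTranspose   = ck::transpose_convert_vectors<ck::half_t, ck::half_t, 2, 2>;  // fp16 -> fp16, transpose only
using TransposeConvert = ck::transpose_convert_vectors<ck::half_t, ck::bhalf_t, 2, 2>; // fp16 -> bf16, transpose + convert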
......@@ -992,6 +992,113 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, float>(float x)
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bf16
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
union
{
float fp32;
uint32_t int32;
} u = {static_cast<float>(x)};
return uint16_t(u.int32 >> 16);
}
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
{
float y0{0}, y1{0};
bhalf2_t y{0};
asm volatile("\n \
v_cvt_f32_f16 %0, %1 \n \
"
: "=v"(y0)
: "v"(x));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
"
: "=v"(y1)
: "v"(x));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
"
: "=v"(y)
: "v"(y0), "v"(y1));
return y;
}
// TODO: deprecate this
template <typename T>
struct inner_product_with_conversion
{
template <typename X, index_t N>
__device__ T operator()(typename vector_type<X, N>::type a,
typename vector_type<X, N>::type b) const
{
const vector_type<X, N> a_vector{a};
const vector_type<X, N> b_vector{b};
T acc = 0;
static_for<0, N, 1>{}([&](auto i) {
acc += type_convert<T>(a_vector.Scalars()[i]) * type_convert<T>(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(float_t a, float_t b) const
{
return type_convert<T>(a) * type_convert<T>(b);
}
__device__ T operator()(int8x4_t a, int8x4_t b) const
{
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
T acc = 0;
static_for<0, 4, 1>{}([&](auto i) {
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x16_t a, int8x16_t b) const
{
const vector_type<int8_t, 16> a_vector{a};
const vector_type<int8_t, 16> b_vector{b};
T acc = 0;
static_for<0, 16, 1>{}([&](auto i) {
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
};
template <typename T>
struct NumericLimits
{
......@@ -1016,4 +1123,17 @@ struct NumericLimits<half_t>
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
};
template <typename T>
struct TypeMap
{
using type = T;
};
#if defined(__gfx90a__)
template <>
struct TypeMap<ck::half_t>
{
using type = ck::bhalf_t;
};
#endif
} // namespace ck
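For reference, a self-contained host model of the truncating fp32 -> bf16 step that these conversions rely on (the helper name is made up; it mirrors the union trick in type_convert<bhalf_t, float> above):
#include <cstdint>
#include <cstring>
// Keep the upper 16 bits of the fp32 bit pattern; this truncates the mantissa
// (rounds toward zero) rather than rounding to nearest, matching the device code.
inline std::uint16_t float_to_bf16_bits(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u)); // bit copy avoids strict-aliasing issues
    return static_cast<std::uint16_t>(u >> 16);
}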
......@@ -13,6 +13,182 @@ template <typename S,
typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct transpose_vectors;
template <typename Sx,
typename Sy,
index_t NX,
index_t NY,
typename enable_if<is_scalar_type<Sx>::value, bool>::type = false,
typename enable_if<is_scalar_type<Sy>::value, bool>::type = false>
struct transpose_convert_vectors;
__device__ void convert_half2_to_bhalf2(const half2_t& x, bhalf2_t& y)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx{x};
vector_type<bhalf_t, 2> vy;
float v0 = static_cast<float>(vx.template AsType<half_t>()[I0]);
float v1 = static_cast<float>(vx.template AsType<half_t>()[I1]);
vy.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
vy.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);
y = vy.template AsType<bhalf2_t>()[I0];
#else
float y0{0}, y1{0};
asm volatile("\n \
v_cvt_f32_f16 %0, %1 \n \
"
: "=v"(y0)
: "v"(x));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
"
: "=v"(y1)
: "v"(x));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
"
: "=v"(y)
: "v"(y0), "v"(y1));
#endif
}
__device__ void
transpose_half_to_bhalf_2x2(const half2_t& x0, const half2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<bhalf_t, 2> vy0, vy1;
float v0 = static_cast<float>(vx0.template AsType<half_t>()[I0]);
float v1 = static_cast<float>(vx1.template AsType<half_t>()[I0]);
vy0.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
vy0.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);
v0 = static_cast<float>(vx0.template AsType<half_t>()[I1]);
v1 = static_cast<float>(vx1.template AsType<half_t>()[I1]);
vy1.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
vy1.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);
y0 = vy0.template AsType<bhalf2_t>()[I0];
y1 = vy1.template AsType<bhalf2_t>()[I0];
#else
float yv0{0}, yv1{0};
asm volatile("\n \
v_cvt_f32_f16 %0, %1 \n \
"
: "=v"(yv0)
: "v"(x0));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 \n \
"
: "=v"(yv1)
: "v"(x1));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
"
: "=v"(y0)
: "v"(yv0), "v"(yv1));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
"
: "=v"(yv0)
: "v"(x0));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
"
: "=v"(yv1)
: "v"(x1));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
"
: "=v"(y1)
: "v"(yv0), "v"(yv1));
#endif
}
template <index_t NX, index_t NY>
struct transpose_convert_vectors<half_t, half_t, NX, NY>
{
// we have [NY * NX] S elements to be transposed
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = half_t;
using VX = vector_type<half_t, s_per_x>;
using VY = vector_type<half_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 half2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
// reference to 2 half2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
// transpose
transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
template <index_t NX, index_t NY>
struct transpose_convert_vectors<half_t, bhalf_t, NX, NY>
{
// we have [NY * NX] S elements to be transposed
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = half_t;
using VX = vector_type<half_t, s_per_x>;
using VY = vector_type<bhalf_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 half2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
// reference to 2 half2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);
// transpose
transpose_half_to_bhalf_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
......
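Written out on scalars, the 2x2 transpose-convert above computes the following (this mirrors the #if 0 reference path; the letters are placeholders):
// inputs : x0 = {a, b}, x1 = {c, d}                 four fp16 lanes
// outputs: y0 = {bf16(a), bf16(c)}, y1 = {bf16(b), bf16(d)}
// i.e. lanes are transposed across the two registers while each lane is
// converted from fp16 to bf16.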
......@@ -35,7 +35,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << static_cast<T>(v);
os << ck::type_convert<T>(v);
}
return os;
}
......
......@@ -77,8 +77,8 @@ ConvParams::ConvParams(ck::index_t n_dim,
conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
throw(
std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"));
}
}
......@@ -91,8 +91,8 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
throw(
std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"));
}
......@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t idx_eff =
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
out_spatial_len[i] =
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
conv_filter_strides[i] +
......
......@@ -45,3 +45,4 @@ add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd)
add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(fp16_transfer_bf16)
\ No newline at end of file
add_test_executable(test_fp16_transfer_bf16 fp16_transfer_bf16.cpp)
target_link_libraries(test_fp16_transfer_bf16 PRIVATE host_tensor)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "check_err.hpp"
#include "transpose_vectors.hpp"
#include "common_header.hpp"
using SrcDataType = ck::half_t;
using DstDataType = ck::bhalf_t;
__global__ void gpu_convert_data(SrcDataType* in, DstDataType* out, int size)
{
using namespace ck;
ck::index_t num = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const auto src_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(in, size);
auto dst_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(out, size);
auto src_data = src_buf.template Get<ck::half2_t>(num, true);
ck::bhalf2_t dst_data;
convert_half2_to_bhalf2(src_data, dst_data);
dst_buf.template Set<ck::bhalf2_t>(num, true, dst_data);
}
__global__ void
gpu_transpose_convert_data(SrcDataType* in, DstDataType* out, const int size, const int stride)
{
using namespace ck;
ck::index_t num = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const auto src_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(in, size);
auto dst_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(out, size);
int x = num % stride;
int y = num / stride;
int num1 = (y + 1) * stride + x;
auto src_data0 = src_buf.template Get<ck::half2_t>(num, true);
auto src_data1 = src_buf.template Get<ck::half2_t>(num1, true);
ck::bhalf2_t dst_data0, dst_data1;
transpose_half_to_bhalf_2x2(src_data0, src_data1, dst_data0, dst_data1);
// transpose the 2x2 result back so the converted data lands in its original
// layout and matches the plain-conversion host reference
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<bhalf_t, 2> vx0{dst_data0}, vx1{dst_data1};
vector_type<bhalf_t, 2> vy0, vy1;
vy0.template AsType<bhalf_t>()(I0) = vx0.template AsType<bhalf_t>()[I0];
vy0.template AsType<bhalf_t>()(I1) = vx1.template AsType<bhalf_t>()[I0];
vy1.template AsType<bhalf_t>()(I0) = vx0.template AsType<bhalf_t>()[I1];
vy1.template AsType<bhalf_t>()(I1) = vx1.template AsType<bhalf_t>()[I1];
dst_buf.template Set<ck::bhalf2_t>(num, true, vy0.template AsType<ck::bhalf2_t>()[I0]);
dst_buf.template Set<ck::bhalf2_t>(num1, true, vy1.template AsType<ck::bhalf2_t>()[I0]);
}
void host_convert_data(SrcDataType* in, DstDataType* out, size_t len)
{
for(int i = 0; i < len; i++)
{
out[i] = ck::type_convert<ck::bhalf_t, ck::half_t>(in[i]);
}
}
int main(int, char*[])
{
bool pass = true;
constexpr int N = 4;
constexpr int K = 4;
constexpr int size = N * K;
constexpr int thread_num = size / 2;
// create tensor
Tensor<SrcDataType> src_n_k_host(
HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));
Tensor<DstDataType> dst_n_k_host_result(
HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));
Tensor<DstDataType> dst_n_k_device_result(
HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));
// init data
src_n_k_host.GenerateTensorValue(GeneratorTensor_3<SrcDataType>{-5, 5});
dst_n_k_host_result.GenerateTensorValue(GeneratorTensor_1<DstDataType>{0});
dst_n_k_device_result.GenerateTensorValue(GeneratorTensor_1<DstDataType>{0});
// alloc gpu memory
DeviceMem in_dev_buf(sizeof(SrcDataType) * src_n_k_host.mDesc.GetElementSpace());
DeviceMem out_dev_buf(sizeof(DstDataType) * dst_n_k_host_result.mDesc.GetElementSpace());
// init gpu memory
in_dev_buf.ToDevice(src_n_k_host.mData.data());
out_dev_buf.SetZero();
// run cpu data convert
host_convert_data(src_n_k_host.mData.data(), dst_n_k_host_result.mData.data(), size);
// run kernel to convert data
gpu_convert_data<<<1, thread_num>>>(static_cast<SrcDataType*>(in_dev_buf.GetDeviceBuffer()),
static_cast<DstDataType*>(out_dev_buf.GetDeviceBuffer()),
src_n_k_host.mDesc.GetElementSpace());
// read from gpu
out_dev_buf.FromDevice(dst_n_k_device_result.mData.data());
pass = ck::utils::check_err(dst_n_k_device_result.mData, dst_n_k_host_result.mData);
// run kernel to transpose and convert data
gpu_transpose_convert_data<<<1, thread_num / 2>>>(
static_cast<SrcDataType*>(in_dev_buf.GetDeviceBuffer()),
static_cast<DstDataType*>(out_dev_buf.GetDeviceBuffer()),
src_n_k_host.mDesc.GetElementSpace(),
K);
// read from gpu
out_dev_buf.FromDevice(dst_n_k_device_result.mData.data());
pass &= ck::utils::check_err(dst_n_k_device_result.mData, dst_n_k_host_result.mData);
#if 1
LogRangeAsType<float>(std::cout << "in : ", src_n_k_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "out device: ", dst_n_k_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "out host: ", dst_n_k_host_result.mData, ",") << std::endl;
#endif
if(pass)
{
std::cout << "fp16 transfer to bf16: Pass" << std::endl;
return 0;
}
else
{
std::cout << "fp16 transfer to bf16: Fail" << std::endl;
return -1;
}
}