"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "28da8b84886d409fc1d536e0bf5b6ac557b174dd"
Commit 9dc05eaa authored by ltqin

change fp16 xdl to bf16

parent c1ed00b6
@@ -85,6 +85,21 @@ template <typename DataType,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 {
+    template <typename T>
+    struct TypeMap
+    {
+        using type = T;
+    };
+
+#if defined(__gfx90a__)
+    template <>
+    struct TypeMap<ck::half_t>
+    {
+        using type = ck::bhalf_t;
+    };
+#endif
+
+    using LDSDataType = typename TypeMap<DataType>::type;
 
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
@@ -126,7 +141,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);
 
-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -142,7 +157,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto
    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
    {
-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -456,7 +471,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_M_K1,
             decltype(a_block_desc_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -481,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -496,13 +511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             true, // DstResetCoord
             NumGemmKPrefetchStage>;
 
-        static constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        static constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            LDSDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -564,7 +580,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            DataType,
+            LDSDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -583,7 +599,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             B1BlockTransferThreadClusterLengths_BK0_N_BK1,
             B1BlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             B1BlockTransferSrcAccessOrder,
@@ -614,11 +630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            LDSDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -634,7 +650,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             GemmKPack,
             true,      // TransposeC
             GemmKPack, // AMmaKStride
-            GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}
+            GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                 .K0PerXdlops /* BMmaKStride */>;
     };
 
@@ -666,7 +682,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMPack =
             math::max(math::lcm(A_M1, B_M1),
-                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
         using BThreadClusterLengths =
@@ -791,7 +807,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            DataType,
+            LDSDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             ElementwiseOp,
@@ -821,7 +837,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             typename Gemm2Params_N_O_M::BThreadClusterLengths,
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_M0_O_M1,
             decltype(b_block_desc_m0_o_m1),
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
@@ -838,7 +854,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 
         using BlockwiseGemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                DataType,
+                                                                LDSDataType,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_m0_n_m1),
                                                                 decltype(b_block_desc_m0_o_m1),
@@ -905,7 +921,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
-        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
+        static constexpr index_t SrcScalarPerVector = 16 / sizeof(FloatGemmAcc);
         static constexpr auto ThreadClusterLength_O =
             Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
@@ -917,7 +933,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
-                                        DataType,
+                                        FloatGemmAcc,
                                         ThreadSliceLength_M * ThreadSliceLength_O,
                                         true>;
 
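Since the y-dot-ygrad reduction now loads 4-byte FloatGemmAcc values instead of 2-byte fp16, each 16-byte vector carries 4 elements instead of 8, which doubles ThreadClusterLength_O and halves ThreadClusterLength_M for the same block slice. A quick arithmetic check with illustrative (not actual) tile sizes BlockSliceLength_O_ = 64 and BlockSize_ = 256:

    static_assert(16 / sizeof(float) == 4, "4 fp32 elements per 16-byte vector (was 8 for fp16)");
    // ThreadClusterLength_O = 64 / 4   = 16   (previously 64 / 8 = 8)
    // ThreadClusterLength_M = 256 / 16 = 16   (previously 256 / 8 = 32)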
@@ -1079,7 +1095,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         static constexpr auto b2_block_desc_m0_o_m1 =
             GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
 
-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
 
         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -1115,13 +1131,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     {
         const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                          SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(DataType);
+                                        sizeof(LDSDataType);
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(DataType);
+            sizeof(LDSDataType);
         const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
                                               SharedMemTrait::ygrad_block_space_size_aligned) *
-                                             sizeof(DataType);
+                                             sizeof(LDSDataType);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
                                           sizeof(FloatGemmAcc);
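Note that ck::half_t and ck::bhalf_t are both 2-byte types, so replacing sizeof(DataType) with sizeof(LDSDataType) in these byte counts does not change how much LDS is reserved; it only keeps the sizing expressions consistent with the element type actually stored there. A sanity check, assuming CK's usual 16-bit definitions of both types:

    static_assert(sizeof(ck::half_t) == 2 && sizeof(ck::bhalf_t) == 2,
                  "fp16 -> bf16 remap is size-neutral for LDS allocation");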
@@ -1224,11 +1240,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 
         // Gemm0: LDS allocation for A and B: be careful of alignment
         auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
             Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // Gemm0: gridwise GEMM pipeline
@@ -1320,11 +1336,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
 
         // Gemm1: VGPR allocation for A and LDS allocation for B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
 
         auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // dQ: transform input and output tensor descriptors
@@ -1516,11 +1532,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
 
         auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
             Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
 
         // dV: transform input and output tensor descriptors
@@ -1640,7 +1656,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
-            DataType,
+            FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
...
@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
     return uint16_t(u.int32 >> 16);
 }
 
+// convert fp16 to bf16
+template <>
+inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {static_cast<float>(x)};
+
+    return uint16_t(u.int32 >> 16);
+}
+
+template <>
+inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
+{
+    float y0{0}, y1{0};
+    bhalf2_t y{0};
+    asm volatile("\n \
+            v_cvt_f32_f16 %0, %1 \n \
+            "
+                 : "=v"(y0)
+                 : "v"(x));
+    asm volatile("\n \
+            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
+            "
+                 : "=v"(y1)
+                 : "v"(x));
+    asm volatile("\n \
+            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
+            "
+                 : "=v"(y)
+                 : "v"(y0), "v"(y1));
+    return y;
+}
+
 template <typename T>
 struct NumericLimits
 {
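Both new specializations produce bf16 by widening fp16 to fp32 and keeping the top 16 bits of the bit pattern (truncation, no rounding), mirroring the float-to-bhalf_t converter directly above. The packed variant does the same per lane in assembly: two v_cvt_f32_f16 instructions widen the low and high fp16 halves (src0_sel:WORD_1 picks the high one), then v_pack_b32_f16 with op_sel:[1, 1] packs the high 16 bits of each fp32 result, which are exactly the truncated bf16 values. A minimal host-side sketch of the scalar truncation, assuming only the standard IEEE-754 float layout (hypothetical helper, not the CK implementation):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // The fp16 -> fp32 widening happens before the call; here we just truncate the bits.
    static std::uint16_t f32_to_bf16_trunc(float f)
    {
        std::uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits)); // well-defined type pun
        return static_cast<std::uint16_t>(bits >> 16);
    }

    int main()
    {
        // 1.5f has bit pattern 0x3FC00000, so the truncated bf16 is 0x3FC0.
        std::printf("%#06x\n", f32_to_bf16_trunc(1.5f)); // prints 0x3fc0
    }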
...