Commit 326e9ee8 authored by danyao12

Support fp16 & bf16 in the multihead attention backward pass

parent 32b03f33
......@@ -10,7 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
......
......@@ -50,9 +50,10 @@ Kernel outputs:
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
......@@ -748,9 +749,12 @@ int run(int argc, char* argv[])
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
}
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
});
#if PRINT_HOST
{
......
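The host-reference change in this hunk computes dS = P * (dP - rowsum(dY * Y)), now accumulating the dY.Y inner product and the final product in AccDataType before converting the result back to the 16-bit storage type. A minimal stand-alone sketch of that pattern, with float/double standing in for the fp16-or-bf16 DataType and the F32 AccDataType (the real code uses ck::type_convert):

#include <cstddef>
#include <vector>

using StorageT = float;   // stand-in for DataType (F16/BF16)
using AccT     = double;  // stand-in for AccDataType (F32)

// dS[m][n] = P[m][n] * (dP[m][n] - sum_o dY[m][o] * Y[m][o]), accumulated in AccT
void sgrad_reference(const std::vector<StorageT>& p,     // P,  M x N
                     const std::vector<StorageT>& pgrad, // dP, M x N
                     const std::vector<StorageT>& y,     // Y,  M x O
                     const std::vector<StorageT>& ygrad, // dY, M x O
                     std::vector<StorageT>& sgrad,       // dS, M x N
                     std::size_t M, std::size_t N, std::size_t O)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        AccT ygrad_dot_y = 0; // row-wise dY . Y reduction in the wider type
        for(std::size_t o = 0; o < O; ++o)
            ygrad_dot_y += static_cast<AccT>(ygrad[m * O + o]) * static_cast<AccT>(y[m * O + o]);

        for(std::size_t n = 0; n < N; ++n)
            sgrad[m * N + n] = static_cast<StorageT>(
                static_cast<AccT>(p[m * N + n]) *
                (static_cast<AccT>(pgrad[m * N + n]) - ygrad_dot_y));
    }
}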
......@@ -49,9 +49,10 @@ Kernel outputs:
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
......@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using DataType = F16;
using DataType = BF16;
using GemmDataType = BF16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
......@@ -101,6 +103,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
......@@ -169,6 +172,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
......@@ -340,16 +344,21 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; // 512
ck::index_t N = 512; // 512
ck::index_t K = 64;
ck::index_t O = 64;
ck::index_t M = 512; // 512
ck::index_t N = 512; // 512
#if USING_HD32
ck::index_t K = 32; // K/O<=32
ck::index_t O = 32;
#else
ck::index_t K = 64; // 32<K/O<=64
ck::index_t O = 64;
#endif
ck::index_t G0 = 4; // 54
ck::index_t G1 = 6; // 16
float alpha = 1.f / std::sqrt(K);
bool input_permute = true; // false;
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
......@@ -747,9 +756,12 @@ int run(int argc, char* argv[])
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
}
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
});
#if PRINT_HOST
{
......
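The layout comment at the top of this hunk (y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]); y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])) describes the two memory layouts that the input_permute / output_permute flags toggle between. A small illustrative helper for the corresponding linear offsets (not CK API, just the index arithmetic):

#include <cstddef>

// offset of element (g0, g1, m, o) in a row-major buffer, either in the plain
// [G0, G1, M, O] layout or in the [0, 2, 1, 3]-permuted [G0, M, G1, O] layout
std::size_t y_offset(std::size_t g0, std::size_t g1, std::size_t m, std::size_t o,
                     std::size_t G1, std::size_t M, std::size_t O, bool permute)
{
    if(permute)
        return ((g0 * M + m) * G1 + g1) * O + o; // y_g0_m_g1_o
    return ((g0 * G1 + g1) * M + m) * O + o;     // y_g0_g1_m_o
}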
......@@ -173,6 +173,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
......@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
DataType, // TODO: distinguish A/B datatype
LSEDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......
......@@ -21,6 +21,7 @@
namespace ck {
template <typename DataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
......@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
......@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
......@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmMWave = Gemm0MWaves;
static constexpr index_t GemmNWave = Gemm0NWaves;
......@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
......@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
......@@ -733,12 +735,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMLoop = Free1_M / Sum_M;
static constexpr index_t GemmMPack =
math::max(A_M1, MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(A_M1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t B_M3 = GemmMPack; // 8
static constexpr index_t B_M2 =
XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3; // 4
static constexpr index_t B_M0 = GemmMLoop; // 2
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3; // 4
static constexpr index_t B_M0 = GemmMLoop; // 2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
{
......@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
ElementwiseOp,
......@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto b_thread_desc_m0_o_m1 = MakeBThreadDesc_M0_O_M1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_o0_o1_o2_m0_m1_m2_m3),
decltype(b_thread_desc_o0_o1_o2_m0_m1_m2_m3),
BThreadSlice_O0_O1_O2_M0_M1_M2_M3,
......@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_m0_n_m1),
decltype(b_thread_desc_m0_o_m1),
......@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
Gemm2Params_N_O_M::GemmMPack,
true, // TransposeC
Gemm2Params_N_O_M::GemmMPack *
XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
.K0PerXdlops,
Gemm2Params_N_O_M::GemmMPack>;
......@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
......@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto p_slash_sgrad_block_desc_m0_n_m1 =
GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
......@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto reduction_space_offset =
(ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) *
sizeof(DataType) / sizeof(FloatGemmAcc);
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
{
const index_t k_bytes_end =
(SharedMemTrait::k_block_space_offset + SharedMemTrait::k_block_space_size_aligned) *
sizeof(DataType);
sizeof(GemmDataType);
const index_t v_bytes_end =
(SharedMemTrait::v_block_space_offset + SharedMemTrait::v_block_space_size_aligned) *
sizeof(DataType);
sizeof(GemmDataType);
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(DataType);
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
......@@ -1263,8 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
......@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// LDS allocation for Q / K / V / dY
auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());
// Q matrix blockwise copy
......@@ -1394,10 +1396,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Gemm1: VGPR allocation for A and B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dQ: transform input and output tensor descriptors
......@@ -1589,10 +1591,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3.GetElementSpaceSize());
// dV: transform input and output tensor descriptors
......@@ -1722,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
......@@ -1735,8 +1737,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
GemmDataType,
FloatGemmAcc,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
......
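One detail worth calling out in this hunk: the LDS block-space offsets are counted in GemmDataType elements, so the *_bytes_end boundaries scale by sizeof(GemmDataType), and reduction_space_offset rescales an element count into FloatGemmAcc elements through the sizeof ratio. A hedged sketch with illustrative sizes (the real counts come from the block descriptors):

#include <cstddef>
#include <cstdint>

using GemmDataType = std::uint16_t; // stand-in for F16/BF16 (2 bytes)
using FloatGemmAcc = float;         // 4 bytes

// illustrative element counts; the real values come from q_block_desc_k0_m_k1 etc.
constexpr std::size_t q_block_space_offset           = 0;
constexpr std::size_t q_block_space_size_aligned     = 4096;
constexpr std::size_t ygrad_block_space_size_aligned = 4096;

// a byte boundary, mirroring the k_bytes_end / v_bytes_end computations above
constexpr std::size_t q_bytes_end =
    (q_block_space_offset + q_block_space_size_aligned) * sizeof(GemmDataType);

// the reduction workspace offset, re-expressed in FloatGemmAcc elements
constexpr std::size_t reduction_space_offset =
    (ygrad_block_space_size_aligned + q_block_space_size_aligned) *
    sizeof(GemmDataType) / sizeof(FloatGemmAcc);

static_assert(q_bytes_end == 8192, "4096 two-byte elements = 8192 bytes");
static_assert(reduction_space_offset == 4096,
              "2-byte gemm elements map to half as many 4-byte accumulator elements");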
......@@ -71,6 +71,141 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return vy.template AsType<double2_t>()[I0];
}
inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
{
half2_t rtn;
rtn[0] = a[0] + b[0];
rtn[1] = a[1] + b[1];
return rtn;
}
union U32FP162_ADDR
{
uint32_t* u32_a;
half2_t* fp162_a;
};
union U32FP162
{
uint32_t u32;
half2_t fp162;
};
template <>
__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
{
U32FP162_ADDR dword_addr;
U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// union U16BF16 {
// uint16_t u16;
// bhalf_t bf16;
// };
// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
// U16BF16 xa {.bf16 = a};
// U16BF16 xb {.bf16 = b};
// U16BF16 xr;
// xr.u16 = xa.u16 + xb.u16;
// return xr.bf16;
// }
inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
{
return type_convert<bhalf_t>(type_convert<float>(a) + type_convert<float>(b));
}
inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
{
bhalf2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]);
return rtn;
}
union U32BF162_ADDR
{
uint32_t* u32_a;
bhalf2_t* bf162_a;
};
union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
};
template <>
__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
{
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
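A hedged CUDA-style sketch of the compare-and-swap loop that the new atomic_add<half2_t> and atomic_add<bhalf2_t> specializations above implement: treat the packed pair of 16-bit values as one 32-bit word, do the addition in registers, and retry atomicCAS until no other thread has modified the word in between. __half2, __hadd2 and __halves2half2 stand in for ck::half2_t and add_fp16x2_t; like the CK specializations, the helper returns the operand x rather than the previous value.

#include <cuda_fp16.h>

__device__ __half2 atomic_add_half2(__half2* p_dst, __half2 x)
{
    // view the destination as a 32-bit word so atomicCAS can operate on it
    unsigned int* dword_addr = reinterpret_cast<unsigned int*>(p_dst);
    unsigned int  cur        = *dword_addr;
    unsigned int  old;
    do
    {
        old = cur;
        // add the packed halves in registers, then try to publish the result
        __half2 updated = __hadd2(*reinterpret_cast<const __half2*>(&old), x);
        cur = atomicCAS(dword_addr, old, *reinterpret_cast<unsigned int*>(&updated));
    } while(cur != old); // another thread raced us; retry with the freshly observed word
    return x;
}

// e.g. many threads accumulating packed fp16 partial results into the same dQ element
__global__ void accumulate_dq(__half2* dq)
{
    atomic_add_half2(dq, __halves2half2(__float2half(1.0f), __float2half(1.0f)));
}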