Commit 326e9ee8 authored by danyao12

Support fp16 & bf16 in the batched multi-head attention backward pass

parent 32b03f33
@@ -10,7 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
 add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
-add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
 add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
......
@@ -51,6 +51,7 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
@@ -748,9 +749,12 @@ int run(int argc, char* argv[])
         {
             auto idx_gmo = idx_gmn;
             idx_gmo[2]   = o;
-            ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
+            ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
+                           ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
         }
-        self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+        self(idx_gmn) = ck::type_convert<DataType>(
+            ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
+            (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
     });
 #if PRINT_HOST
     {
......
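The host-reference change above computes the softmax backward identity dS = P * (dP - sum_o(dY * Y)) per row, now keeping the reduction and the product in AccDataType (float) and narrowing to DataType only for the final store, so fp16/bf16 rounding is not compounded across the O dimension. A minimal standalone sketch of the same pattern (hypothetical function name, plain std::vector rather than CK's Tensor API):

```cpp
#include <cstddef>
#include <vector>

// dS = P * (dP - sum_o(dY * Y)) per row, accumulated in AccDataType and only
// narrowed back to DataType at the end (mirrors the reference change above).
template <typename DataType, typename AccDataType = float>
void sgrad_host_reference(const std::vector<DataType>& p,     // P,  M x N
                          const std::vector<DataType>& pgrad, // dP, M x N
                          const std::vector<DataType>& y,     // Y,  M x O
                          const std::vector<DataType>& ygrad, // dY, M x O
                          std::vector<DataType>& sgrad,       // dS, M x N (output)
                          std::size_t M,
                          std::size_t N,
                          std::size_t O)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // row-wise reduction dY . Y, accumulated in float
        AccDataType ygrad_dot_y = 0;
        for(std::size_t o = 0; o < O; ++o)
        {
            ygrad_dot_y += static_cast<AccDataType>(ygrad[m * O + o]) *
                           static_cast<AccDataType>(y[m * O + o]);
        }

        // product also formed in float; convert once when storing
        for(std::size_t n = 0; n < N; ++n)
        {
            sgrad[m * N + n] = static_cast<DataType>(
                static_cast<AccDataType>(p[m * N + n]) *
                (static_cast<AccDataType>(pgrad[m * N + n]) - ygrad_dot_y));
        }
    }
}
```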
@@ -50,6 +50,7 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
 
-using DataType = F16;
+using DataType = BF16;
+using GemmDataType = BF16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
@@ -101,6 +103,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -169,6 +172,7 @@ using DeviceGemmInstance =
         NumDimK,
        NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -342,14 +346,19 @@ int run(int argc, char* argv[])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t M = 512; // 512
     ck::index_t N = 512; // 512
-    ck::index_t K = 64;
+#if USING_HD32
+    ck::index_t K = 32; // K/O<=32
+    ck::index_t O = 32;
+#else
+    ck::index_t K = 64; // 32<K/O<=64
     ck::index_t O = 64;
+#endif
     ck::index_t G0 = 4; // 54
     ck::index_t G1 = 6; // 16
 
     float alpha = 1.f / std::sqrt(K);
 
-    bool input_permute = true; // false;
+    bool input_permute = false;
     bool output_permute = false;
 
     float p_drop = 0.2;
@@ -747,9 +756,12 @@ int run(int argc, char* argv[])
         {
             auto idx_gmo = idx_gmn;
             idx_gmo[2]   = o;
-            ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
+            ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
+                           ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
         }
-        self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+        self(idx_gmn) = ck::type_convert<DataType>(
+            ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
+            (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
     });
 #if PRINT_HOST
     {
......
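In this bf16 example the storage type (DataType) and the new GEMM type (GemmDataType) both become BF16 while AccDataType stays F32. A hedged one-liner (assumed include paths) showing the bf16 round trip that motivates keeping arithmetic in float:

```cpp
#include "ck/utility/data_type.hpp"    // assumed location of ck::bhalf_t
#include "ck/utility/type_convert.hpp" // assumed location of ck::type_convert

// Narrow a float to bf16 and widen it back; the difference is the storage
// rounding error, which is why the reference and the kernels accumulate in F32.
inline float roundtrip_bf16(float x)
{
    return ck::type_convert<float>(ck::type_convert<ck::bhalf_t>(x));
}
```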
@@ -173,6 +173,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
         DataType, // TODO: distinguish A/B datatype
-        LSEDataType,
+        GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
+        LSEDataType,
         AElementwiseOperation,
         BElementwiseOperation,
         AccElementwiseOperation,
......
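The new GemmDataType template parameter separates the global-memory storage type (DataType) from the element type fed through LDS and the MFMA pipeline; the device op simply forwards it to the gridwise kernel. A much-simplified sketch of that plumbing (illustrative names only, not the real CK signatures):

```cpp
// Gridwise kernel: LDS sizing, MfmaSelector and the blockwise GEMMs would all
// key off GemmDataType, while DataType remains the type read from / written to
// global memory.
template <typename DataType, typename GemmDataType, typename FloatGemmAcc>
struct GridwiseAttentionBwdSketch
{
};

// Device op: carries both types and forwards them unchanged.
template <typename DataType, typename GemmDataType, typename FloatGemmAcc>
struct DeviceAttentionBwdSketch
{
    using GridwiseGemm = GridwiseAttentionBwdSketch<DataType, GemmDataType, FloatGemmAcc>;
};
```

In the examples above the two types are set to the same value (F16/F16 or BF16/BF16) with F32 accumulation; the split simply keeps storage and MFMA input types independent of each other.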
@@ -21,6 +21,7 @@
 namespace ck {
 
 template <typename DataType,
+          typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatLSE,
@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);
 
-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3 = mfma.num_groups_per_blk;
         constexpr auto N4 = mfma.num_input_blks;
         constexpr auto N5 = mfma.group_size;
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(q_block_desc_k0_m_k1),
             ABlockTransferSrcAccessOrder,
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(k_block_desc_k0_n_k1),
             BBlockTransferSrcAccessOrder,
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(v_block_desc_k0_n_k1),
             BBlockTransferSrcAccessOrder,
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(ygrad_block_desc_k0_m_k1),
             ABlockTransferSrcAccessOrder,
@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                                       BBlockDesc_BK0_N_BK1{});
         }
 
-        static constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        static constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            DataType,
+            GemmDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         static constexpr index_t GemmMWave = Gemm0MWaves;
         static constexpr index_t GemmNWave = Gemm0NWaves;
@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
 
         using BBlockwiseCopy =
-            ThreadwiseTensorSliceTransfer_v2<DataType,
-                                             DataType,
+            ThreadwiseTensorSliceTransfer_v2<GemmDataType,
+                                             GemmDataType,
                                              decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
                                              decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
                                              BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_thread_desc_k0_n_k1),
@@ -733,10 +735,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMLoop = Free1_M / Sum_M;
         static constexpr index_t GemmMPack =
-            math::max(A_M1, MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+            math::max(A_M1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         static constexpr index_t B_M3 = GemmMPack; // 8
         static constexpr index_t B_M2 =
-            XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
+            XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
         static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3; // 4
         static constexpr index_t B_M0 = GemmMLoop; // 2
@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            DataType,
+            GemmDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             ElementwiseOp,
@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto b_thread_desc_m0_o_m1 = MakeBThreadDesc_M0_O_M1();
 
         using BBlockwiseCopy =
-            ThreadwiseTensorSliceTransfer_v2<DataType,
-                                             DataType,
+            ThreadwiseTensorSliceTransfer_v2<GemmDataType,
+                                             GemmDataType,
                                              decltype(b_block_desc_o0_o1_o2_m0_m1_m2_m3),
                                              decltype(b_thread_desc_o0_o1_o2_m0_m1_m2_m3),
                                              BThreadSlice_O0_O1_O2_M0_M1_M2_M3,
@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_m0_n_m1),
             decltype(b_thread_desc_m0_o_m1),
@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             Gemm2Params_N_O_M::GemmMPack,
             true, // TransposeC
             Gemm2Params_N_O_M::GemmMPack *
-                XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
+                XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
                     .K0PerXdlops,
             Gemm2Params_N_O_M::GemmMPack>;
@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
-                                        DataType,
+                                        FloatGemmAcc,
                                         ThreadSliceLength_M * ThreadSliceLength_O,
                                         true>;
@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto p_slash_sgrad_block_desc_m0_n_m1 =
             GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
 
-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
 
         static constexpr auto q_block_space_size_aligned =
             math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto reduction_space_offset =
             (ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) *
-            sizeof(DataType) / sizeof(FloatGemmAcc);
+            sizeof(GemmDataType) / sizeof(FloatGemmAcc);
 
         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     {
         const index_t k_bytes_end =
             (SharedMemTrait::k_block_space_offset + SharedMemTrait::k_block_space_size_aligned) *
-            sizeof(DataType);
+            sizeof(GemmDataType);
         const index_t v_bytes_end =
             (SharedMemTrait::v_block_space_offset + SharedMemTrait::v_block_space_size_aligned) *
-            sizeof(DataType);
+            sizeof(GemmDataType);
         const index_t p_slash_sgrad_bytes_end =
             (SharedMemTrait::p_slash_sgrad_block_space_offset +
              SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
-            sizeof(DataType);
+            sizeof(GemmDataType);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
                                           sizeof(FloatGemmAcc);
@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // LDS allocation for Q / K / V / dY
         auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
             GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
 
         auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
             GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
 
         auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
             GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize());
 
         auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
             GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());
 
         // Q matrix blockwise copy
@@ -1394,10 +1396,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
 
         // Gemm1: VGPR allocation for A and B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
-        auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
 
         // dQ: transform input and output tensor descriptors
@@ -1589,10 +1591,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
 
-        auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3.GetElementSpaceSize());
 
         // dV: transform input and output tensor descriptors
@@ -1722,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // performs for y
         auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
-            DataType,
+            FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
@@ -1735,8 +1737,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // performs for ygrad
         auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-            DataType,
-            DataType,
+            GemmDataType,
+            FloatGemmAcc,
            decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
             decltype(ygrad_thread_desc_m_o),
             decltype(ygrad_thread_desc_m_o.GetLengths()),
......
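Throughout the gridwise kernel the LDS sizes and byte offsets are now computed from sizeof(GemmDataType), e.g. max_lds_align = 16 / sizeof(GemmDataType) elements. A small standalone illustration of what that alignment works out to (plain C++ stand-ins for the CK element types):

```cpp
#include <cstddef>
#include <cstdio>

// 16 bytes expressed in elements of GemmDataType, mirroring
// Number<16 / sizeof(GemmDataType)>{} in the kernel above.
template <typename GemmDataType>
constexpr std::size_t max_lds_align_in_elements()
{
    return 16 / sizeof(GemmDataType);
}

int main()
{
    std::printf("2-byte (fp16/bf16): %zu elements, 4-byte (fp32): %zu elements\n",
                max_lds_align_in_elements<unsigned short>(), // 2-byte stand-in for F16/BF16
                max_lds_align_in_elements<float>());
    return 0;
}
```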
@@ -71,6 +71,141 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
     return vy.template AsType<double2_t>()[I0];
 }
 
+inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
+{
+    half2_t rtn;
+    rtn[0] = a[0] + b[0];
+    rtn[1] = a[1] + b[1];
+    return rtn;
+}
+
+union U32FP162_ADDR
+{
+    uint32_t* u32_a;
+    half2_t* fp162_a;
+};
+
+union U32FP162
+{
+    uint32_t u32;
+    half2_t fp162;
+};
+
+template <>
+__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
+{
+    U32FP162_ADDR dword_addr;
+    U32FP162 cur_v;
+    U32FP162 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.fp162_a = p_dst;
+    cur_v.u32          = *dword_addr.u32_a;
+
+    do
+    {
+        old_v      = cur_v.u32;
+        new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
+        new_v      = new_.u32;
+        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+
+    return x;
+}
+
+// template <>
+// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
+// {
+//     uint32_t* dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+//     uint32_t cur_v       = *dword_addr;
+//     uint32_t old_v, new_v;
+//     do
+//     {
+//         old_v        = cur_v;
+//         half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
+//         new_v        = *reinterpret_cast<uint32_t*>(&new_);
+//         cur_v        = atomicCAS(dword_addr, old_v, new_v);
+//     } while(cur_v != old_v);
+//     return x;
+// }
+
+// union U16BF16
+// {
+//     uint16_t u16;
+//     bhalf_t bf16;
+// };
+// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
+// {
+//     U16BF16 xa{.bf16 = a};
+//     U16BF16 xb{.bf16 = b};
+//     U16BF16 xr;
+//     xr.u16 = xa.u16 + xb.u16;
+//     return xr.bf16;
+// }
+
+inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
+{
+    return type_convert<bhalf_t>(type_convert<float>(a) + type_convert<float>(b));
+}
+
+inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
+{
+    bhalf2_t rtn;
+    rtn[0] = add_bf16_t(a[0], b[0]);
+    rtn[1] = add_bf16_t(a[1], b[1]);
+    return rtn;
+}
+
+union U32BF162_ADDR
+{
+    uint32_t* u32_a;
+    bhalf2_t* bf162_a;
+};
+
+union U32BF162
+{
+    uint32_t u32;
+    bhalf2_t bf162;
+};
+
+template <>
+__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
+{
+    U32BF162_ADDR dword_addr;
+    U32BF162 cur_v;
+    U32BF162 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.bf162_a = p_dst;
+    cur_v.u32          = *dword_addr.u32_a;
+
+    do
+    {
+        old_v      = cur_v.u32;
+        new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
+        new_v      = new_.u32;
+        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+
+    return x;
+}
+
+// template <>
+// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
+// {
+//     uint32_t* dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+//     uint32_t cur_v       = *dword_addr;
+//     uint32_t old_v, new_v;
+//     do
+//     {
+//         old_v         = cur_v;
//         bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
+//         new_v         = *reinterpret_cast<uint32_t*>(&new_);
+//         cur_v         = atomicCAS(dword_addr, old_v, new_v);
+//     } while(cur_v != old_v);
+//     return x;
+// }
+
 // Caution: DO NOT REMOVE
 // intentionally have only declaration but no definition to cause compilation failure when trying to
 // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
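The two new specializations emulate a packed 16-bit add (fp16x2 / bf16x2) by retrying a 32-bit atomicCAS on the containing dword until it comes back unchanged, so the destination must be dword-aligned, which holds for arrays of the 2-element vector types. A hypothetical device-side usage sketch (kernel name, launch, and include paths are illustrative assumptions, not part of this commit):

```cpp
#include <hip/hip_runtime.h>
#include "ck/utility/data_type.hpp" // assumed location of ck::half2_t
// atomic_add<half2_t> is the CAS-based specialization added in this header

// Each thread folds one packed 2-element partial result into dst[i]; concurrent
// updates to the same dword are serialized by the CAS retry loop.
__global__ void accumulate_partials(ck::half2_t* dst, const ck::half2_t* partial, int n_pairs)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n_pairs)
    {
        ck::atomic_add<ck::half2_t>(&dst[i], partial[i]);
    }
}
```

The bf16 path goes through float for the element-wise add (add_bf16_t), since integer addition of raw bf16 bit patterns (the commented-out U16BF16 variant) does not produce a correct sum.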