Commit c9133170 authored by danyao12

support fwd fp16&bf16

parent 45b52f13
@@ -5,10 +5,8 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
+add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp)
add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
-using F16 = ck::half_t;
-using F32 = float;
-using U16 = unsigned short;
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32 = float;
+using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType = F16;
-using B0DataType = F16;
-using B1DataType = F16;
+using DataType = BF16;
+using GemmDataType = BF16;
+using ADataType = DataType;
+using B0DataType = DataType;
+using B1DataType = DataType;
using AccDataType = F32;
using CShuffleDataType = F32;
-using CDataType = F16;
+using CDataType = DataType;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
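A note on the aliases above: even with every tensor element in BF16, `AccDataType` stays F32 because bf16 keeps only 7 mantissa bits, so long softmax/GEMM reductions would otherwise drop low-order contributions. A minimal, CK-free sketch of that rounding (the truncating conversion is for illustration only; CK's own conversions go through `ck::type_convert`):

```cpp
// Minimal sketch: bf16 is the top 16 bits of an fp32, so anything finer than
// 2^-7 relative to the leading bit is dropped. This is why the example keeps
// AccDataType = F32 even when every tensor element is BF16.
#include <cstdint>
#include <cstdio>
#include <cstring>

static std::uint16_t to_bf16(float x) // truncating fp32 -> bf16, illustration only
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return static_cast<std::uint16_t>(u >> 16);
}

static float from_bf16(std::uint16_t h)
{
    std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
    float x;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    float x = 1.0f + 1.0f / 256.0f; // representable in fp32, not in bf16
    std::printf("fp32 %.8f -> bf16 round-trip %.8f\n", x, from_bf16(to_bf16(x)));
    // prints: fp32 1.00390625 -> bf16 round-trip 1.00000000
}
```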
@@ -81,6 +84,7 @@ using DeviceGemmInstance =
B0DataType,
B1DataType,
CDataType,
+GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -139,7 +143,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp,
Acc0ElementOp>;
-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
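Taken together, the three reference instances check C = Softmax(scale * A * B0) * B1 on the host with fp32 accumulation, independent of whether `DataType` is F16 or BF16. A dependency-free sketch of that pipeline (function and parameter names are illustrative, not CK's; the dropout/Z path is omitted):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative host reference: C = Softmax(scale * A * B0) * B1.
// All buffers are row-major and already converted to float (AccDataType).
void reference_attention(const std::vector<float>& A,  // M x K
                         const std::vector<float>& B0, // K x N
                         const std::vector<float>& B1, // N x O
                         std::vector<float>& C,        // M x O
                         int M, int K, int N, int O, float scale)
{
    std::vector<float> S(static_cast<std::size_t>(M) * N);
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n) // Gemm0, fp32 accumulation
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * N + n];
            S[m * N + n] = scale * acc;
        }
        float mx = S[m * N]; // numerically stable row-wise softmax
        for(int n = 1; n < N; ++n)
            mx = std::fmax(mx, S[m * N + n]);
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            S[m * N + n] = std::exp(S[m * N + n] - mx);
            sum += S[m * N + n];
        }
        for(int n = 0; n < N; ++n)
            S[m * N + n] /= sum;
        for(int o = 0; o < O; ++o) // Gemm1
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += S[m * N + n] * B1[n * O + o];
            C[m * O + o] = acc;
        }
    }
}
```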
@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
-using F16 = ck::half_t;
-using F32 = float;
-using U16 = unsigned short;
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32 = float;
+using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType = F16;
-using B0DataType = F16;
-using B1DataType = F16;
+using DataType = F16;
+using GemmDataType = F16;
+using ADataType = DataType;
+using B0DataType = DataType;
+using B1DataType = DataType;
using AccDataType = F32;
using CShuffleDataType = F32;
-using CDataType = F16;
+using CDataType = DataType;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
@@ -72,7 +75,6 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
@@ -82,6 +84,7 @@ using DeviceGemmInstance =
B0DataType,
B1DataType,
CDataType,
+GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -140,7 +143,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
@@ -149,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp,
Acc0ElementOp>;
-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
@@ -158,6 +158,7 @@ template <index_t NumDimG,
typename BDataType,
typename B1DataType,
typename CDataType,
+typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
+GemmDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
@@ -148,6 +148,7 @@ template <index_t NumDimG,
typename BDataType,
typename B1DataType,
typename CDataType,
+typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
+GemmDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
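Both device files apply the same change: the new `GemmDataType` parameter is accepted at the device layer and forwarded into the gridwise kernel as `FloatGemm`, separating the global-memory element type (`ADataType`) from the type staged in LDS and issued to the MFMA unit. A skeletal illustration of that threading (all names below are stand-ins, not the CK declarations):

```cpp
// Skeleton of the type-threading pattern, not the real CK classes.
template <typename FloatAB, typename FloatGemm, typename FloatGemmAcc>
struct GridwiseKernelSketch
{
    // FloatAB:      element type as loaded from global memory
    // FloatGemm:    element type staged in LDS and fed to the matrix unit
    // FloatGemmAcc: accumulator type (fp32 in these examples)
};

template <typename ADataType, typename GemmDataType, typename GemmAccDataType>
struct DeviceAttentionSketch
{
    using GridwiseGemm =
        GridwiseKernelSketch<ADataType, GemmDataType, GemmAccDataType>;
};

// The examples instantiate with GemmDataType == ADataType (both F16 or both
// BF16), but the parameter now allows the two to diverge.
using F16Stub = unsigned short; // stand-in for ck::half_t, keeps the sketch header-free
using Sketch  = DeviceAttentionSketch<F16Stub, F16Stub, float>;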
@@ -21,6 +21,7 @@
namespace ck {
template <typename FloatAB,
+typename FloatGemm,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
@@ -126,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
-constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
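The switch to `FloatGemm` matters here because N3/N4/N5 come from the MFMA instruction chosen for the compute type, and fp16 and bf16 can map to different instructions. A shape-only sketch of the selector idea (the numeric values are placeholders, not CK's MFMA table):

```cpp
// Shape-only sketch: a selector maps (element type, tile) to one MFMA
// instruction, whose register layout then fixes the mask-descriptor dims.
struct MfmaTraits
{
    int num_groups_per_blk; // -> N3
    int num_input_blks;     // -> N4
    int group_size;         // -> N5
};

template <typename FloatGemm, int MPerXdl, int NPerXdl>
struct MfmaSelectorSketch
{
    // Placeholder values; the real table differs per type and tile size.
    static constexpr MfmaTraits selected_mfma{4, 2, 4};
};
```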
@@ -242,10 +243,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
-sizeof(FloatAB);
+sizeof(FloatGemm);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-sizeof(FloatAB);
+sizeof(FloatGemm);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
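Since `ck::half_t` and `ck::bhalf_t` are both 2 bytes, this `sizeof` change keeps the byte math identical for fp16 and bf16; it would only diverge if `FloatGemm` were ever wider than `FloatAB`. A toy version of the budgeting arithmetic with hypothetical tile sizes (the real offsets come from `SharedMemTrait`):

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int elem_bytes      = 2;        // sizeof(FloatGemm), fp16 or bf16
    constexpr int a_block_elems   = 128 * 32; // hypothetical A tile in LDS
    constexpr int b_block_elems   = 128 * 32; // hypothetical B tile in LDS
    constexpr int gemm0_bytes_end = (a_block_elems + b_block_elems) * elem_bytes;
    constexpr int reduction_bytes = 2 * 256 * 4; // hypothetical fp32 softmax workspace
    std::printf("gemm0 needs %d B, softmax needs %d B, LDS high-water mark %d B\n",
                gemm0_bytes_end, reduction_bytes,
                std::max(gemm0_bytes_end, reduction_bytes));
}
```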
@@ -495,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
-FloatAB,
+FloatGemm,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
@@ -526,7 +527,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
-FloatAB,
+FloatGemm,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
@@ -554,12 +555,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
-constexpr index_t KPack = math::max(
-math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+constexpr index_t KPack =
+math::max(math::lcm(AK1, BK1),
+MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
-FloatAB,
+FloatGemm,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
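For a concrete feel of the `KPack` computation, take hypothetical values AK1 = BK1 = 8 and an MFMA whose `k_per_blk` is 4:

```cpp
// Worked instance of the KPack formula above (values hypothetical):
//   KPack = max(lcm(AK1, BK1), k_per_blk) = max(lcm(8, 8), 4) = 8
#include <algorithm>
#include <numeric>
static_assert(std::max(std::lcm(8, 8), 4) == 8, "KPack worked example");
```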
@@ -579,11 +581,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
+static_cast<FloatGemm*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
+static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
@@ -658,7 +660,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
-FloatAB,
+FloatGemm,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
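This copy is the point where gemm0's fp32 result (the softmax output) is narrowed to `FloatGemm` before becoming gemm1's A operand. A scalar stand-in for that per-element conversion (`static_cast` here assumes a float-convertible type; CK's bf16 path goes through `ck::type_convert`, and the real transfer is the `ThreadwiseTensorSliceTransfer_StaticToStatic` above):

```cpp
// Stand-in for the acc (fp32) -> FloatGemm narrowing done by a1_blockwise_copy.
#include <cstddef>
#include <vector>

template <typename FloatGemm>
std::vector<FloatGemm> narrow_acc(const std::vector<float>& acc_thread_buf)
{
    std::vector<FloatGemm> a1_thread_buf(acc_thread_buf.size());
    for(std::size_t i = 0; i < acc_thread_buf.size(); ++i)
        a1_thread_buf[i] = static_cast<FloatGemm>(acc_thread_buf[i]);
    return a1_thread_buf;
}
```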
@@ -677,7 +679,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
FloatAB,
-FloatAB,
+FloatGemm,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
@@ -698,12 +700,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
-auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemm>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
@@ -716,11 +718,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
-MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
+MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
-FloatAB,
+FloatGemm,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
@@ -736,7 +738,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
-Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
+Gemm1KPack * XdlopsGemm<FloatGemm, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
@@ -1100,7 +1102,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
+(is_same_v<FloatGemm, bhalf_t>)&&MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
{
__builtin_amdgcn_sched_barrier(0);
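The guard confines the scheduling barrier to the one known-bad configuration (bf16, 256x128 tile with Gemm1NPerBlock 128 on gfx908) behind a macro defined in `ck/ck.hpp`. A self-contained analogue of that gating pattern (the macro name and type check are stand-ins so the sketch compiles off-target):

```cpp
#include <type_traits>

#ifndef WORKAROUND_ENABLED // stand-in for the CK_WORKAROUND_* switch in ck/ck.hpp
#define WORKAROUND_ENABLED 1
#endif

template <typename FloatGemm, int MPerBlock>
void pipeline_step()
{
    // "if constexpr" confines the workaround to the offending instantiation;
    // every other template instance compiles to nothing extra.
    if constexpr(WORKAROUND_ENABLED == 1 &&
                 std::is_same_v<FloatGemm, unsigned short> && MPerBlock == 256)
    {
        // __builtin_amdgcn_sched_barrier(0) would sit here on an AMDGPU
        // target; omitted so this sketch compiles on any host.
    }
}
```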