"...community/pipeline_stable_diffusion_upscale_ldm3d.py" did not exist on "12a232efa99d7a8c33f54ae515c5a3d6fc5c8f34"
Commit c9133170 authored by danyao12

support fwd fp16&bf16

parent 45b52f13
@@ -5,10 +5,8 @@
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
+add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp)
 add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
......
@@ -33,17 +33,20 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-using ADataType = F16;
-using B0DataType = F16;
-using B1DataType = F16;
+using DataType = BF16;
+using GemmDataType = BF16;
+using ADataType = DataType;
+using B0DataType = DataType;
+using B1DataType = DataType;
 using AccDataType = F32;
 using CShuffleDataType = F32;
-using CDataType = F16;
+using CDataType = DataType;
 using ZDataType = U16;
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
@@ -81,6 +84,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -139,7 +143,7 @@ using DeviceGemmInstance =
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization

-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
                                                                                 AccDataType,
@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                 B0ElementOp,
                                                                                 Acc0ElementOp>;

-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
 using ReferenceSoftmaxInstance =
     ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
 using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B1DataType,
                                                                                 CDataType,
......
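The alias block above is the whole switching mechanism: every tensor type (ADataType, B0DataType, B1DataType, CDataType) now derives from a single DataType, the new GemmDataType names the element type handed to the MFMA path, and accumulation stays in F32 either way. A minimal sketch of the pattern, with stand-in structs since ck::half_t / ck::bhalf_t need the CK headers (everything here besides the alias names is invented scaffolding):

    // Sketch only: flipping the DataType alias retargets every tensor type.
    #include <cstdio>
    #include <type_traits>

    struct half_t  { unsigned short bits; }; // stand-in for ck::half_t
    struct bhalf_t { unsigned short bits; }; // stand-in for ck::bhalf_t

    using DataType     = bhalf_t;  // set to half_t for the fp16 variant
    using GemmDataType = DataType; // element type fed to the MFMA selector
    using ADataType    = DataType;
    using B0DataType   = DataType;
    using B1DataType   = DataType;
    using CDataType    = DataType;
    using AccDataType  = float;    // accumulation stays fp32 in both variants

    int main()
    {
        std::printf("bf16 variant: %s\n",
                    std::is_same_v<DataType, bhalf_t> ? "yes" : "no");
        return 0;
    }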
@@ -33,17 +33,20 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-using ADataType = F16;
-using B0DataType = F16;
-using B1DataType = F16;
+using DataType = F16;
+using GemmDataType = F16;
+using ADataType = DataType;
+using B0DataType = DataType;
+using B1DataType = DataType;
 using AccDataType = F32;
 using CShuffleDataType = F32;
-using CDataType = F16;
+using CDataType = DataType;
 using ZDataType = U16;
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
@@ -72,7 +75,6 @@
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -82,6 +84,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -140,7 +143,7 @@ using DeviceGemmInstance =
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization

-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
                                                                                 AccDataType,
@@ -149,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                 B0ElementOp,
                                                                                 Acc0ElementOp>;

-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
 using ReferenceSoftmaxInstance =
     ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
 using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B1DataType,
                                                                                 CDataType,
......
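The three host reference instances (ReferenceGemm0 -> ReferenceSoftmax -> ReferenceGemm1) verify the fused kernel stage by stage, and the comments rewritten above describe the dataflow: low-precision in, fp32 accumulation, softmax in fp32, low-precision back out. A plain-float toy of that pipeline (shapes and values invented; this is not the CK host API):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        constexpr int M = 2, N = 3, K = 4, O = 2; // queries, keys, head dim, value dim
        float Q[M][K], Kmat[N][K], V[N][O], S[M][N], P[M][N], C[M][O] = {};

        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
                Q[m][k] = 0.1f * (m + k);
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                Kmat[n][k] = 0.05f * (n - k);
        for(int n = 0; n < N; ++n)
            for(int o = 0; o < O; ++o)
                V[n][o] = 1.0f + 0.5f * o;

        // Gemm0: S = Q * K^T, accumulated in float (AccDataType)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                S[m][n] = 0.f;
                for(int k = 0; k < K; ++k)
                    S[m][n] += Q[m][k] * Kmat[n][k];
            }

        // Softmax over each row, in float; the kernel then narrows to DataType
        for(int m = 0; m < M; ++m)
        {
            float mx = S[m][0], sum = 0.f;
            for(int n = 1; n < N; ++n)
                mx = std::fmax(mx, S[m][n]);
            for(int n = 0; n < N; ++n)
            {
                P[m][n] = std::exp(S[m][n] - mx);
                sum += P[m][n];
            }
            for(int n = 0; n < N; ++n)
                P[m][n] /= sum;
        }

        // Gemm1: C = P * V (narrowed back to DataType in the real kernel)
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
                for(int n = 0; n < N; ++n)
                    C[m][o] += P[m][n] * V[n][o];

        std::printf("C[0][0] = %f\n", C[0][0]);
        return 0;
    }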
@@ -158,6 +158,7 @@ template <index_t NumDimG,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
+        GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
......
@@ -148,6 +148,7 @@ template <index_t NumDimG,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
+        GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
......
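Both device ops gain GemmDataType immediately after CDataType, so any code that instantiates DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle or DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle positionally must now pass one extra type. A toy illustration of why this is a source-breaking change (the real parameter list is far longer; the names below are stand-ins):

    using F16 = unsigned short; // bit-width stand-in for ck::half_t
    using U16 = unsigned short;

    template <typename CDataType,
              typename GemmDataType, // new parameter added mid-list
              typename ZDataType>
    struct DeviceOpSketch
    {
    };

    // Old call sites wrote DeviceOpSketch<F16, U16>; after the change the
    // gemm element type must be spelled out explicitly:
    using Instance = DeviceOpSketch<F16, F16, U16>;

    int main()
    {
        Instance i;
        (void)i;
        return 0;
    }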
@@ -21,6 +21,7 @@
 namespace ck {

 template <typename FloatAB,
+          typename FloatGemm,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatC,
@@ -126,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        const auto M = z_grid_desc_m_n.GetLength(I0);
        const auto N = z_grid_desc_m_n.GetLength(I1);

-       constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+       constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
        constexpr auto N3 = mfma.num_groups_per_blk;
        constexpr auto N4 = mfma.num_input_blks;
        constexpr auto N5 = mfma.group_size;
@@ -242,10 +243,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
    {
        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                         SharedMemTrait::b_block_space_size_aligned) *
-                                       sizeof(FloatAB);
+                                       sizeof(FloatGemm);
        const index_t gemm1_bytes_end =
            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-           sizeof(FloatAB);
+           sizeof(FloatGemm);
        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                           SharedMemTrait::reduction_space_size_aligned) *
                                          sizeof(FloatGemmAcc);
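Switching the LDS sizing from sizeof(FloatAB) to sizeof(FloatGemm) ties the byte math to the type actually staged in LDS. For fp16 vs bf16 the number is identical (both are 2 bytes), so the change is semantic: it matters only if the gemm type ever diverges from the A/B I/O type. A back-of-envelope version of the computation (the element counts are invented, not SharedMemTrait's real values):

    #include <cstdio>

    int main()
    {
        // invented stand-ins for SharedMemTrait's aligned element counts
        constexpr int a_block_space_size_aligned = 128 * 32;
        constexpr int b_block_space_size_aligned = 128 * 32;
        constexpr int float_gemm_bytes           = 2; // fp16 and bf16 alike

        constexpr int gemm0_bytes_end =
            (a_block_space_size_aligned + b_block_space_size_aligned) * float_gemm_bytes;

        std::printf("gemm0 LDS bytes: %d\n", gemm0_bytes_end); // 16384
        return 0;
    }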
@@ -495,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
@@ -526,7 +527,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
@@ -554,12 +555,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
-       constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       constexpr index_t KPack =
+           math::max(math::lcm(AK1, BK1),
+                     MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
-           FloatAB,
+           FloatGemm,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -579,11 +581,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle

        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
+           static_cast<FloatGemm*>(p_shared) + SharedMemTrait::a_block_space_offset,
            a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
+           static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b_block_space_offset,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
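The reshuffled KPack expression is the same constraint as before with FloatGemm substituted: the packing factor must cover both vector widths (lcm(AK1, BK1)) and the selected MFMA's k_per_blk. The rule in standard C++, with invented values (CK uses its own math:: helpers and real MFMA parameters):

    #include <algorithm> // std::max
    #include <cstdio>
    #include <numeric>   // std::lcm

    int main()
    {
        constexpr int AK1 = 8, BK1 = 4;   // invented per-load vector widths
        constexpr int mfma_k_per_blk = 4; // invented mfma packing
        constexpr int KPack = std::max(std::lcm(AK1, BK1), mfma_k_per_blk);
        std::printf("KPack = %d\n", KPack); // 8
        return 0;
    }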
@@ -658,7 +660,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // A1 matrix blockwise copy
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatGemmAcc,
-           FloatAB,
+           FloatGemm,
            decltype(acc_thread_desc_k0_m_k1),
            decltype(a1_thread_desc_k0_m_k1),
            tensor_operation::element_wise::PassThrough,
@@ -677,7 +679,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(b1_grid_desc_bk0_n_bk1),
            decltype(b1_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder,
@@ -698,12 +700,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            make_multi_index(0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

-       auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+       auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemm>(
            a1_thread_desc_k0_m_k1.GetElementSpaceSize());

        // reuse LDS space for gemm0's b_block_buf
        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+           static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
@@ -716,11 +718,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        constexpr index_t Gemm1KPack =
-           MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
+           MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;

        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
-           FloatAB,
+           FloatGemm,
            FloatGemmAcc,
            decltype(a1_thread_desc_k0_m_k1),
            decltype(b1_block_desc_bk0_n_bk1),
@@ -736,7 +738,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            Gemm1KPack,
            true,       // TransposeC
            Gemm1KPack, // AMmaKStride
-           Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
+           Gemm1KPack * XdlopsGemm<FloatGemm, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
            // BMmaKStride
            make_tuple(0, 0, 0, 0)}; // A_origin
@@ -1100,7 +1102,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle

        // workaround compiler issue; see ck/ck.hpp
        if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-                    is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
+                    (is_same_v<FloatGemm, bhalf_t>)&&MPerBlock == 256 && NPerBlock == 128 &&
                     Gemm1NPerBlock == 128)
        {
            __builtin_amdgcn_sched_barrier(0);
......
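The last hunk re-keys the gfx908 scheduling workaround on FloatGemm, since that is now the type that decides whether a bf16 MFMA path is taken. The shape of that compile-time gate, as a host-compilable sketch (the macro name and the __builtin_amdgcn_sched_barrier call come from the diff; the surrounding scaffolding is invented):

    #include <cstdio>
    #include <type_traits>

    struct bhalf_t { unsigned short bits; }; // stand-in for ck::bhalf_t

    template <typename FloatGemm, int MPerBlock, int NPerBlock, int Gemm1NPerBlock>
    void run()
    {
        if constexpr(std::is_same_v<FloatGemm, bhalf_t> && MPerBlock == 256 &&
                     NPerBlock == 128 && Gemm1NPerBlock == 128)
        {
            // the real kernel issues __builtin_amdgcn_sched_barrier(0) here,
            // which only compiles for AMDGPU targets
            std::puts("bf16 256x128x128 tile: workaround path");
        }
        else
        {
            std::puts("default path");
        }
    }

    int main()
    {
        run<bhalf_t, 256, 128, 128>();
        return 0;
    }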