"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "93b671cfdc184ca283c4feacb8db996e6ff1ea51"
Unverified commit ac3ef99c, authored by Dan Yao, committed by GitHub

Merge pull request #1010 from ROCmSoftwarePlatform/mha-train-develop-bias-shfl

Add bias with shuffle for flash attention fwd
parents acea1753 e87ddb0e
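
Note (for orientation, not part of the commit): the kernels below compute C = softmax(A·B0 + D0)·B1, where D0 is the optional attention bias. The change replaces the old per-thread global loads of D0 (a 10-dimensional threadwise copy) with a block-wide copy that stages the bias tile through LDS and shuffles it out to the owning threads, and retunes the B1 block-transfer parameters to match. A minimal scalar reference of the computation being tiled; the function and its names are illustrative, not from this repository:

```cpp
#include <cmath>
#include <vector>

// Naive single-batch reference for C = softmax(A * B0 + D0) * B1.
// A: M x K, B0: K x N, D0: M x N (bias), B1: N x O, C: M x O (preallocated).
void attention_bias_ref(const std::vector<float>& A,
                        const std::vector<float>& B0,
                        const std::vector<float>& D0,
                        const std::vector<float>& B1,
                        std::vector<float>& C,
                        int M, int K, int N, int O)
{
    std::vector<float> S(static_cast<size_t>(M) * N);
    for(int m = 0; m < M; ++m)
    {
        // S = A * B0 + D0, the "acc add bias" step inside the kernel
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * N + n];
            S[m * N + n] = acc + D0[m * N + n];
        }
        // numerically stable row-wise softmax
        float mx = S[m * N];
        for(int n = 1; n < N; ++n)
            mx = std::fmax(mx, S[m * N + n]);
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
            sum += (S[m * N + n] = std::exp(S[m * N + n] - mx));
        // C = softmax(S) * B1
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (S[m * N + n] / sum) * B1[n * O + o];
            C[m * O + o] = acc;
        }
    }
}
```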
@@ -101,7 +101,7 @@ using DeviceGemmInstance =
         32, // Gemm1KPerBlock
         8,  // AK1
         8,  // BK1
-        2,  // B1K1
+        4,  // B1K1
         32, // MPerXDL
         32, // NPerXDL
         1,  // MXdlPerWave
@@ -121,13 +121,13 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        4,
-        S<16, 16, 1>, // B1BlockTransfer
+        8,
+        S<8, 32, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        DIM / 32,
         4,
+        2,
         false,
         1, // CShuffleMXdlPerWavePerShuffle
         2, // CShuffleNXdlPerWavePerShuffle
...
@@ -120,7 +120,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        1,
+        8,
         S<16, 16, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
...
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
...
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
...
@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+          typename D0GridDescriptor_M0_N0_N1_N2_M1_N3,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
@@ -60,8 +60,7 @@ __global__ void
         const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
         const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c1_grid_desc_mblock_mperblock_nblock_nperblock,
-        const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
         const Block2CTileMap block_2_ctile_map,
         const index_t batch_count,
         const index_t h_ratio,
@@ -109,7 +108,7 @@ __global__ void
             c1de_element_op,
             a_grid_desc_ak0_m_ak1,
             b_grid_desc_bk0_n_bk1,
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            d0_grid_desc_m0_n0_m1_m2_n1_m3,
             b1_grid_desc_bk0_n_bk1,
             c1_grid_desc_mblock_mperblock_nblock_nperblock,
             block_2_ctile_map,
@@ -129,7 +128,7 @@ __global__ void
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = b1_grid_desc_bk0_n_bk1;
     ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+    ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = h_ratio;
@@ -206,8 +205,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
-          int D0sTransferSrcScalarPerVector = 4,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
     : public DeviceBatchedMultiheadAttentionInfer<NumDimG,
                                                   NumDimM,
@@ -537,9 +535,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
         {
             D0GridDesc_M_N d0_grid_desc_m_n_ = MakeD0GridDescriptor_M_N(
                 acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
-            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
-                GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-                    d0_grid_desc_m_n_);
+            d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
+                GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n_);

             d0_grid_desc_g_m_n_ = MakeD0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
                                                              acc0_bias_gs_ms_ns_strides);
@@ -592,8 +589,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
         typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c1_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+        typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;

         // block-to-c-tile map
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -660,7 +656,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
             DeviceOp::BGridDesc_BK0_N_BK1,
             DeviceOp::B1GridDesc_BK0_N_BK1,
             typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-            typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+            typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3,
             typename GridwiseGemm::DefaultBlock2CTileMap,
             ComputeBasePtrOfStridedBatch,
             C0MatrixMask,
@@ -685,7 +681,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                 arg.b_grid_desc_bk0_n_bk1_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
                 arg.h_ratio_,
...
@@ -112,7 +112,7 @@ __global__ void
             c_element_op,
             arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
             arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
             arg_ptr[group_id].block_2_ctile_map_,
@@ -465,8 +465,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+        typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -576,9 +575,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
             const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
                 tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
-            const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-                GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-                    d0_grid_desc_m_n);
+            const auto d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
+                GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n);

             const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
                 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
@@ -628,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
                 p_c_grid,
                 a_grid_desc_ak0_m_ak1,
                 b_grid_desc_bk0_n_bk1,
-                d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                 b1_grid_desc_bk0_n_bk1,
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
...
@@ -92,11 +92,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
     static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);

-    static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
-                      D0BlockTransferSrcScalarPerVector == 2 ||
-                      D0BlockTransferSrcScalarPerVector == 4,
-                  "D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
-
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
@@ -366,6 +361,125 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;

+    static constexpr auto D0N2 = AK1;
+    static constexpr auto D0N1 = AK0;
+    static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
+    static_assert(NPerBlock % KPerBlock == 0);
+
+    __host__ __device__ static constexpr auto
+    MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
+    {
+        const auto M = d0_grid_desc_m_n.GetLength(I0);
+        const auto N = d0_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto d0_grid_desc_m0_n0_n1_n2_m1_n3 = transform_tensor_descriptor(
+            d0_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, D0N0, D0N1, D0N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3, 5>{}));
+
+        return d0_grid_desc_m0_n0_n1_n2_m1_n3;
+    }
+
+    using D0GridDescriptor_M0_N0_N1_N2_M1_N3 =
+        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(D0GridDesc_M_N{}))>;
+
+    struct D0Operator
+    {
+        static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
+        static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector ==
+                      0);
+
+        template <typename DataType>
+        struct TypeTransform
+        {
+            using Type = DataType;
+        };
+
+        template <>
+        struct TypeTransform<void>
+        {
+            using Type = ck::half_t;
+        };
+
+        __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
+        {
+            return make_naive_tensor_descriptor_packed(
+                make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
+        }
+
+        __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
+        {
+            constexpr auto d0_raw_n0_m_n1 =
+                make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
+
+            constexpr auto d0_raw_m_n = transform_tensor_descriptor(
+                d0_raw_n0_m_n1,
+                make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
+                           make_merge_transform(make_tuple(D0N1, D0N2))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
+                d0_raw_m_n,
+                make_tuple(make_unmerge_transform(
+                               make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
+                           make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
+
+            return d0_m0_m1_n0_n1_n2;
+        }
+
+        static constexpr auto d0_thread_desc_ =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
+
+        static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
+            GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
+        static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
+            GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
+
+        using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            tensor_operation::element_wise::PassThrough,
+            tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set,
+            Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
+            typename sequence_merge<Sequence<1, 1, 1>,
+                                    ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
+            Sequence<0, 1, 2, 4, 3, 5>,
+            typename TypeTransform<D0DataType>::Type, // SrcData
+            typename TypeTransform<D0DataType>::Type, // DstData
+            D0GridDescriptor_M0_N0_N1_N2_M1_N3,
+            decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
+            Sequence<0, 1, 2, 4, 3, 5>,
+            Sequence<0, 1, 2, 4, 3, 5>,
+            5,
+            5,
+            D0BlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_AK1,
+            1,
+            1,
+            true, // SrcResetCoord
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>;
+
+        using D0ThreadwiseCopyLdsToVgpr =
+            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
+                                             typename TypeTransform<D0DataType>::Type, // DstData
+                                             decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
+                                             decltype(d0_thread_desc_),                  // DstDesc
+                                             Sequence<1, 1, 4, 1, 4>, // SliceLengths
+                                             Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
+                                             4,                       // SrcVectorDim
+                                             4,                       // SrcScalarPerVector
+                                             2>;
+    };
+
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
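
Note on the new descriptor: MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3 splits the flat (m, n) coordinate of the bias into (M0, N0, N1, N2, M1, N3) = (block row, block column, KPerBlock-sized chunk, AK0 slice, row within the tile, AK1 element). The N tile is re-cut with the A matrix's K split (D0N0 = NPerBlock / KPerBlock, D0N1 = AK0, D0N2 = AK1) so the bias can ride the A block-transfer thread cluster. A standalone sketch of the same index decomposition, assuming MPerBlock = NPerBlock = 128, KPerBlock = 32, AK0 = 4, AK1 = 8 (illustrative values only):

```cpp
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    constexpr int D0N2 = 8;                     // = AK1 (assumed)
    constexpr int D0N1 = KPerBlock / D0N2;      // = AK0 = 4
    constexpr int D0N0 = NPerBlock / KPerBlock; // = 4; D0N0 * D0N1 * D0N2 == NPerBlock

    const int m = 300, n = 200; // a flat coordinate of the bias matrix D0

    const int M0 = m / MPerBlock;               // block row
    const int N0 = n / NPerBlock;               // block column
    const int N1 = (n % NPerBlock) / KPerBlock; // KPerBlock chunk (0..D0N0-1)
    const int N2 = (n % KPerBlock) / D0N2;      // AK0 slice       (0..D0N1-1)
    const int M1 = m % MPerBlock;               // row within tile
    const int N3 = n % D0N2;                    // AK1 element     (0..D0N2-1)

    std::printf("(m=%d, n=%d) -> M0=%d N0=%d N1=%d N2=%d M1=%d N3=%d\n",
                m, n, M0, N0, N1, N2, M1, N3);
}
```

This matches the Sequence<0, 4> / Sequence<1, 2, 3, 5> mapping in the transform above: M feeds output dims 0 and 4, N feeds dims 1, 2, 3, and 5.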
@@ -403,26 +517,26 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     };

     template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
-                               const D0DataType* __restrict__ p_d0_grid,
-                               const FloatAB* __restrict__ p_b1_grid,
-                               FloatC* __restrict__ p_c_grid,
-                               void* __restrict__ p_shared,
-                               const AElementwiseOperation& a_element_op,
-                               const BElementwiseOperation& b_element_op,
-                               const AccElementwiseOperation& acc_element_op,
-                               const B1ElementwiseOperation& b1_element_op,
-                               const CElementwiseOperation& c_element_op,
-                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
-                                   d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
-                               const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   c1_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map,
-                               const C0MatrixMask& c0_matrix_mask)
+    __device__ static void
+    Run(const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        const D0DataType* __restrict__ p_d0_grid,
+        const FloatAB* __restrict__ p_b1_grid,
+        FloatC* __restrict__ p_c_grid,
+        void* __restrict__ p_shared,
+        const AElementwiseOperation& a_element_op,
+        const BElementwiseOperation& b_element_op,
+        const AccElementwiseOperation& acc_element_op,
+        const B1ElementwiseOperation& b1_element_op,
+        const CElementwiseOperation& c_element_op,
+        const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+        const D0GridDescriptor_M0_N0_N1_N2_M1_N3& d0_grid_desc_m0_n0_n1_n2_m1_n3,
+        const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+        const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+            c1_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2CTileMap& block_2_ctile_map,
+        const C0MatrixMask& c0_matrix_mask)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -683,49 +797,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         const auto wave_id     = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-        // bias (d0 matrix)
-        constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
-                                                           I1, // NBlockId
-                                                           m0, // MRepeat
-                                                           n0, // NRepeat
-                                                           m1, // MWaveId
-                                                           n1, // NWaveId
-                                                           m2, // MPerXdl
-                                                           n2, // NGroupNum
-                                                           n3, // NInputNum
-                                                           n4)); // RegisterNum
-
-        auto d0_threadwise_copy =
-            ThreadwiseTensorSliceTransfer_v2<D0DataType,
-                                             D0DataType,
-                                             decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                                             decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                                             Sequence<I1, // MBlockId
-                                                      I1, // NBlockID
-                                                      m0, // MRepeat
-                                                      n0, // NRepeat
-                                                      m1, // MWaveId
-                                                      n1, // NWaveId
-                                                      m2, // MPerXdl
-                                                      n2, // NGroupNum
-                                                      n3, // NInputNum
-                                                      n4>,
-                                             Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                                             9,
-                                             D0BlockTransferSrcScalarPerVector,
-                                             1,
-                                             false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                    make_multi_index(block_work_idx[I0], // MBlockId
-                                                                     0,           // NBlockId
-                                                                     0,           // mrepeat
-                                                                     0,           // nrepeat
-                                                                     wave_id[I0], // MWaveId
-                                                                     wave_id[I1], // NWaveId
-                                                                     wave_m_n_id[I1], // MPerXdl
-                                                                     0,               // group
-                                                                     wave_m_n_id[I0], // NInputIndex
-                                                                     0)); // register number
-
         // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
         // selected_mfma.k_per_blk <= Gemm1KPack
@@ -827,6 +898,17 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

         // gemm1 K loop
+        auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
+            d0_grid_desc_m0_n0_n1_n2_m1_n3,
+            make_multi_index(block_work_idx[I0], 0, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{},
+            D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3,
+            make_multi_index(0, 0, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{});
+
+        auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
+            make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+
         index_t gemm1_k_block_outer_index = 0;
         do
         {
@@ -920,30 +1002,53 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                 // add bias
                 if constexpr(!is_same<D0DataType, void>::value)
                 {
-                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                        p_d0_grid,
-                        d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
-
-                    // get register
-                    StaticBuffer<AddressSpaceEnum::Vgpr,
-                                 D0DataType,
-                                 d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
-                                 true>
-                        d0_thread_buf;
-
-                    // load data from global
-                    d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                           d0_grid_buf,
-                                           d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                           d0_thread_buf);
-
-                    // acc add bias
-                    static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
-                        acc_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
-                    });
-
-                    d0_threadwise_copy.MoveSrcSliceWindow(
-                        d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                    if(p_d0_grid != nullptr)
+                    {
+                        static constexpr auto& c_thread_desc = blockwise_gemm.GetCThreadDesc();
+
+                        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                            p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+
+                        auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                            static_cast<D0DataType*>(p_shared) +
+                                SharedMemTrait::a_block_space_offset,
+                            D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+
+                        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                            D0Operator::d0_thread_desc_.GetElementSpaceSize());
+
+                        static_for<0, D0N0, 1>{}([&](auto nr) {
+                            // load data to lds
+                            d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                                                                d0_grid_buf);
+
+                            d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                                d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
+
+                            d0_block_copy_global_to_lds.RunWrite(
+                                D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);
+
+                            block_sync_lds();
+
+                            // read data from lds
+                            d0_thread_copy_lds_to_vgpr.Run(
+                                D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
+                                make_tuple(I0, I0, I0, I0, I0),
+                                d0_block_buf,
+                                D0Operator::d0_thread_desc_,
+                                make_tuple(I0, I0, I0, I0, I0),
+                                d0_thread_buf);
+
+                            // bias add
+                            static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
+                                constexpr index_t c_offset =
+                                    c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
+                                acc_thread_buf(Number<c_offset>{}) +=
+                                    ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                            });
+                        });
+
+                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                            d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                            make_multi_index(0, 1, -D0N0.value, 0, 0, 0));
+                    }
                 }

                 // softmax
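
Note on the rewritten bias path above: each outer-loop step streams the D0 tile in D0N0 chunks. RunRead pulls one chunk from global memory, MoveSrcSliceWindow advances the chunk index (N1 += 1), RunWrite lands it in LDS, and after a block sync each thread copies its piece to registers and accumulates into acc_thread_buf at the offset selected by nr. The final MoveSrcSliceWindow steps to the next NPerBlock tile while rewinding the chunk index (the -D0N0.value term). A single-threaded CPU analogue of that loop structure, with illustrative names and sizes:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    constexpr int NPerBlock = 128, KPerBlock = 32;
    constexpr int D0N0 = NPerBlock / KPerBlock; // chunks per N tile

    std::vector<float> d0_global(NPerBlock, 1.0f); // one bias row ("global memory")
    std::vector<float> acc(NPerBlock, 0.0f);       // accumulator tile ("VGPRs")
    float lds[KPerBlock];                          // stand-in for the LDS staging buffer

    for(int nr = 0; nr < D0N0; ++nr)
    {
        // RunRead + RunWrite: stage one KPerBlock chunk in "LDS"
        for(int i = 0; i < KPerBlock; ++i)
            lds[i] = d0_global[nr * KPerBlock + i];

        // threadwise LDS -> register copy, then the bias add
        for(int i = 0; i < KPerBlock; ++i)
            acc[nr * KPerBlock + i] += lds[i];
    }
    std::printf("acc[0] = %.1f\n", acc[0]);
}
```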
...