Commit e938bd61 authored by letaoqin's avatar letaoqin
Browse files

use a block thread cluster lengths

parent 29d5cbac
...@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator struct D0Operator
{ {
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
{ {
...@@ -444,29 +446,30 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -444,29 +446,30 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 = static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2(); GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
using D0BlockwiseCopyGlobalToLds = using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>, Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
Sequence<I1, I1, I1, 4, 64, 1>, typename sequence_merge<Sequence<1, 1, 1>,
Sequence<0, 1, 2, 4, 3, 5>, ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
typename TypeTransform<D0DataType>::Type, // SrcData Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // SrcData
D0GridDescriptor_M0_N0_N1_N2_M1_N3, typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3), D0GridDescriptor_M0_N0_N1_N2_M1_N3,
Sequence<0, 1, 2, 4, 3, 5>, decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
Sequence<0, 1, 2, 4, 3, 5>, Sequence<0, 1, 2, 4, 3, 5>,
5, Sequence<0, 1, 2, 4, 3, 5>,
5, 5,
ABlockTransferSrcScalarPerVector, 5,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferSrcScalarPerVector,
1, ABlockTransferDstScalarPerVector_AK1,
1, 1,
true, // SrcResetCoord 1,
true, // DstResetCoord true, // SrcResetCoord
NumGemmKPrefetchStage>; true, // DstResetCoord
NumGemmKPrefetchStage>;
using D0ThreadwiseCopyLdsToVgpr = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment