Commit e938bd61 authored by letaoqin
Browse files

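Use the A-block thread cluster lengths (ABlockTransferThreadClusterLengths_AK0_M_AK1) for the D0 global-to-LDS copy instead of the hard-coded 4, 64, 1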

parent 29d5cbac
...@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator struct D0Operator
{ {
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
{ {
...@@ -444,13 +446,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -444,13 +446,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 = static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2(); GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
using D0BlockwiseCopyGlobalToLds = using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>, Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
Sequence<I1, I1, I1, 4, 64, 1>, typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>, Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment