Commit e938bd61 authored by letaoqin's avatar letaoqin
Browse files

use a block thread cluster lengths

parent 29d5cbac
......@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator
{
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
template <typename DataType>
struct TypeTransform
{
......@@ -444,13 +446,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
using D0BlockwiseCopyGlobalToLds =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
Sequence<I1, I1, I1, 4, 64, 1>,
typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment