Commit e938bd61 authored by letaoqin
Browse files

Use the A-block thread cluster lengths

parent 29d5cbac
......@@ -391,6 +391,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator
{
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
template <typename DataType>
struct TypeTransform
{
......@@ -444,29 +446,30 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// Compile-time descriptor object returned by GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2().
// NOTE(review): the identifier suggests a five-dimensional (M0, M1, N0, N1, N2)
// layout used as a block-level copy source; the helper's definition is not
// visible here - confirm against it.
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
// Alias for the ThreadGroupTensorSliceTransfer_v4r1 instantiation that copies
// D0 between D0GridDescriptor_M0_N0_N1_N2_M1_N3 and
// d0_block_dst_desc_m0_n0_n1_n2_m1_n3, with PassThrough element-wise operations
// and InMemoryDataOperationEnum::Set.
// In this declaration the second Sequence argument is the hard-coded
// Sequence<I1, I1, I1, 4, 64, 1>.
// NOTE(review): the same alias name is declared again directly below with that
// argument derived from ABlockTransferThreadClusterLengths_AK0_M_AK1; this
// looks like the pre-change side of a diff - confirm only one declaration
// survives in the real header.
// NOTE(review): the role of each positional argument is fixed by the
// template's own declaration, which is not visible here - confirm before
// reordering or editing any of them.
using D0BlockwiseCopyGlobalToLds =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
// Hard-coded six-entry sequence; presumably the per-dimension thread cluster lengths - confirm.
Sequence<I1, I1, I1, 4, 64, 1>,
Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_N1_N2_M1_N3,
decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
Sequence<0, 1, 2, 4, 3, 5>,
Sequence<0, 1, 2, 4, 3, 5>,
5,
5,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
// Alias for the ThreadGroupTensorSliceTransfer_v4r1 instantiation that copies
// D0 between D0GridDescriptor_M0_N0_N1_N2_M1_N3 and
// d0_block_dst_desc_m0_n0_n1_n2_m1_n3, with PassThrough element-wise operations
// and InMemoryDataOperationEnum::Set.
// The second Sequence argument is not hard-coded: it is the result of
// sequence_merge applied to Sequence<1, 1, 1> and
// ABlockTransferThreadClusterLengths_AK0_M_AK1.
// NOTE(review): the role of each positional argument is fixed by the
// template's own declaration, which is not visible here - confirm before
// reordering or editing any of them.
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
// Presumably sequence_merge concatenates its arguments, giving
// <1, 1, 1, AK0, M, AK1> - six entries to match the six-entry Sequence above.
// The static_assert in D0Operator requires the A-block sequence to have
// exactly 3 entries. TODO confirm sequence_merge semantics.
typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_N1_N2_M1_N3,
decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
Sequence<0, 1, 2, 4, 3, 5>,
Sequence<0, 1, 2, 4, 3, 5>,
5,
5,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment