gaoqiong / composable_kernel · Commits · 57271814

Commit 57271814 · authored Jun 13, 2022 by Chao Liu · parent f9b92b1e

    refactor

Showing 3 changed files with 117 additions and 212 deletions:

- include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp (+77, -101)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp (+32, -87)
- include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp (+8, -24)
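The refactor replaces the fixed three-source, one-destination interface of `ThreadGroupTensorSliceTransfer_v7` with tuple-typed `SrcDatas`/`SrcDescs` and `DstDatas`/`DstDescs` parameters, so the same transfer works for any number of D tensors. As a rough standard-C++ sketch of the idea (not CK code; `run_elementwise` and all names here are made up for illustration), a single tuple-taking entry point replaces a fixed-arity one:

```cpp
#include <cstdio>
#include <tuple>

// One Run() over a tuple of sources replaces a fixed src0/src1/src2 interface:
// the arity becomes a property of the tuple type, not of the function.
template <typename Op, typename Dst, typename... Srcs>
void run_elementwise(Op op, Dst& dst, const std::tuple<Srcs...>& srcs)
{
    dst = std::apply(op, srcs); // expands to op(get<0>(srcs), get<1>(srcs), ...)
}

int main()
{
    // E = elementwise_op(C, D0, D1), the shape of the GEMM-multiple-D epilogue.
    auto add3 = [](float c, float d0, float d1) { return c + d0 + d1; };

    float e = 0.f;
    run_elementwise(add3, e, std::make_tuple(1.0f, 2.0f, 3.0f));
    std::printf("e = %f\n", e); // e = 6.000000
}
```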
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp (+77, -101)

`ThreadGroupTensorSliceTransfer_v7` is generalized from three hardcoded sources (`Src0*`, `Src1*`, `Src2*`) and one destination to arbitrary tuples of sources and destinations.

@@ -14,60 +14,62 @@ namespace ck {

Before:

```cpp
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v7
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const Src2Desc& src2_desc,
                                                           const Index& src2_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(tie(src0_desc, src1_desc, src2_desc),
                               make_tuple(make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>()),
                               tie(dst_desc),
                               make_tuple(make_zero_multi_index<nDim>()),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");
```

After:

```cpp
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename ThreadTransferSrcResetCoordinateAfterRunFlags,
          typename ThreadTransferDstResetCoordinateAfterRunFlags,
          InMemoryDataOperationEnum... DstInMemOps>
struct ThreadGroupTensorSliceTransfer_v7
{
    static constexpr index_t nDim =
        remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();

    static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
    static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

    using Index = MultiIndex<nDim>;

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_descs,
                               StaticallyIndexedArray<Index, nSrc>{},
                               dst_descs,
                               StaticallyIndexedArray<Index, nDst>{},
                               element_op)
    {
        static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                          nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
                          nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
                          nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
                      "wrong!");

        static_for<0, nSrc, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_assert(nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");
```

@@ -87,73 +89,51 @@ struct ThreadGroupTensorSliceTransfer_v7

Before:

```cpp
        const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

        threadwise_transfer_.SetSrcSliceOrigin(
            tie(src0_desc, src1_desc, src2_desc),
            make_tuple(src0_block_slice_origin + thread_data_idx_begin,
                       src1_block_slice_origin + thread_data_idx_begin,
                       src2_block_slice_origin + thread_data_idx_begin));

        threadwise_transfer_.SetDstSliceOrigin(
            tie(dst_desc), make_tuple(dst_block_slice_origin + thread_data_idx_begin));
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
                                     tie(src0_buf, src1_buf, src2_buf),
                                     tie(dst_desc),
                                     tie(dst_buf));
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(tie(src0_desc, Src1Desc{}, Src2Desc{}), step, I0);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(tie(Src0Desc{}, src1_desc, Src2Desc{}), step, I1);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(tie(Src0Desc{}, Src1Desc{}, src2_desc), step, I2);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(tie(dst_desc), step, I0);
        }
    }
```

After:

```cpp
        const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

        const auto src_thread_slice_origins = generate_tuple(
            [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
            Number<nSrc>{});

        const auto dst_thread_slice_origins = generate_tuple(
            [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
            Number<nDst>{});

        threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
        threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
    }

    template <typename SrcBuffers, typename DstBuffers>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
        }
    }

    template <index_t ISrc>
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
        }
    }

    template <index_t IDst>
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
        }
    }
```

@@ -161,23 +141,19 @@ struct ThreadGroupTensorSliceTransfer_v7

Before:

```cpp
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
        Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
        Tuple<remove_cvref_t<DstData>>,
        Tuple<remove_reference_t<Src0Desc>&,
              remove_reference_t<Src1Desc>&,
              remove_reference_t<Src2Desc>&>,
        Tuple<remove_reference_t<DstDesc>&>,
        ElementwiseOperation,
        decltype(thread_slice_lengths),
        DimAccessOrder,
        VectorDim,
        ScalarPerVector,
        Sequence<ThreadTransferSrc0ResetCoordinateAfterRun,
                 ThreadTransferSrc1ResetCoordinateAfterRun,
                 ThreadTransferSrc2ResetCoordinateAfterRun>,
        Sequence<ThreadTransferDstResetCoordinateAfterRun>,
        DstInMemOp>;

    ThreadwiseTransfer threadwise_transfer_;
};
```

After:

```cpp
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
        SrcDatas,
        DstDatas,
        SrcDescs,
        DstDescs,
        ElementwiseOperation,
        decltype(thread_slice_lengths),
        DimAccessOrder,
        VectorDim,
        ScalarPerVector,
        ThreadTransferSrcResetCoordinateAfterRunFlags,
        ThreadTransferDstResetCoordinateAfterRunFlags,
        DstInMemOps...>;

    ThreadwiseTransfer threadwise_transfer_;
};
```
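The new constructor builds per-thread slice origins with `generate_tuple` and validates every tuple slot with `static_for`. Below is a minimal standalone analogue of those two utilities in standard C++17 (illustrative only; CK's real `generate_tuple` and `static_for` live in its own utility headers and differ in detail):

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

// static_for analogue: invoke f(integral_constant<size_t, I>) for I = 0..N-1.
template <typename F, std::size_t... Is>
constexpr void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl(f, std::make_index_sequence<N>{});
}

// generate_tuple analogue: build a tuple whose I-th element is g(I).
template <typename G, std::size_t... Is>
constexpr auto generate_tuple_impl(G&& g, std::index_sequence<Is...>)
{
    return std::make_tuple(g(std::integral_constant<std::size_t, Is>{})...);
}

template <std::size_t N, typename G>
constexpr auto generate_tuple(G&& g)
{
    return generate_tuple_impl(g, std::make_index_sequence<N>{});
}

int main()
{
    // Per-source slice origin = block origin + this thread's offset,
    // computed for every tuple slot in one expression, as in the new ctor.
    const int block_origins[3] = {0, 100, 200};
    const int thread_offset    = 4;

    auto origins = generate_tuple<3>(
        [&](auto i) { return block_origins[decltype(i)::value] + thread_offset; });

    static_for<3>([&](auto i) {
        std::printf("src%zu origin = %d\n",
                    decltype(i)::value, std::get<decltype(i)::value>(origins));
    });
}
```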
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp (+32, -87)

The C+Ds-to-E shuffle epilogue of `GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle` switches to the tuple-based transfer: the LDS C-shuffle descriptor and the two D grid descriptors are bundled once as `c_ds_descs`, and the per-source window moves become a compile-time loop over the D tensors.

@@ -542,77 +542,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle

Before:

```cpp
            ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
                ThisThreadBlock,            // ThreadGroup
                CDEElementwiseOperation,    // ElementwiseOperation,
                EGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename Src0Data,
                remove_cvref_t<tuple_element_t<0, DsDataType>>, // typename Src1Data,
                remove_cvref_t<tuple_element_t<1, DsDataType>>, // typename Src2Data,
                FloatE,               // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
                decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
                decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                3,                    // index_t VectorDim,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrc0ResetCoordinateAfterRun,
                false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
                false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 e_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 cde_element_op};
```

After (the superseded instantiation, tagged `// FIXME: arbitrary # of D tensors`, is retained in a disabled `#if 0` branch; the enabled `#else` branch is shown):

```cpp
            ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            const auto c_ds_descs = tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
                                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]);

            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
                ThisThreadBlock,         // ThreadGroup
                CDEElementwiseOperation, // ElementwiseOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                Tuple<FloatCShuffle,
                      remove_cvref_t<tuple_element_t<0, DsDataType>>,
                      remove_cvref_t<tuple_element_t<1, DsDataType>>>,
                Tuple<FloatE>,        // typename DstData,
                decltype(c_ds_descs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                3,                    // index_t VectorDim,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags
                Sequence<false>,              // bool ThreadTransferDstResetCoordinateAfterRunFlags
                EGlobalMemoryDataOperation>   // DstInMemOp,
                {c_ds_descs,
                 make_tuple(make_multi_index(0, 0, 0, 0),
                            make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                            make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                 cde_element_op};

            // space filling curve for threadwise C in VGPR before shuffle
            constexpr auto sfc_c_vgpr =
```

@@ -655,42 +617,25 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle

Before:

```cpp
                block_sync_lds();

                // each block copy its data from LDS to global
                cde_block_copy_lds_and_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
                    ds_grid_buf[I0],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
                    ds_grid_buf[I1],
                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                    e_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_cde_block.GetForwardStep(access_id);

                    // move on Ds
                    cde_block_copy_lds_and_global.MoveSrc1SliceWindow(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step);

                    cde_block_copy_lds_and_global.MoveSrc2SliceWindow(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);

                    // move on E
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
```

After (the page shows the Run call wrapped in `#if 1`/`#else`/`#endif`; the branch consistent with the new interface is shown):

```cpp
                block_sync_lds();

                // each block copy its data from LDS to global
                cde_block_copy_lds_and_global.Run(
                    c_ds_descs,
                    tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]),
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    tie(e_grid_buf));

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto e_global_step = sfc_cde_block.GetForwardStep(access_id);

                    // move on Ds
                    static_for<0, DsDataType::Size(), 1>{}([&](auto i) {
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_descs, i + I1, e_global_step);
                    });

                    // move on E
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, e_global_step);
                }
            });
        }
```
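The `static_for` loop above moves the window of source `i + 1` (skipping source 0, the LDS C-shuffle buffer) and relies on the window index being a compile-time `Number<>` so that the tuple access resolves statically. A small illustrative sketch of that dispatch pattern in standard C++ (the `Window` type and function names are hypothetical, not CK code):

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>

template <std::size_t I>
using Number = std::integral_constant<std::size_t, I>; // compile-time index tag

struct Window { int origin; };

// Move only the I-th source window; I is a type-level constant, so the
// tuple element is selected at compile time with no runtime dispatch.
template <std::size_t I, typename... Ws>
void move_src_slice_window(std::tuple<Ws...>& windows, Number<I>, int step)
{
    std::get<I>(windows).origin += step;
}

int main()
{
    // Source 0 plays the role of the LDS C-shuffle window (never moved
    // between accesses); sources 1..N are D-tensor windows sharing one step.
    std::tuple<Window, Window, Window> srcs{Window{0}, Window{10}, Window{20}};

    move_src_slice_window(srcs, Number<1>{}, 5);
    move_src_slice_window(srcs, Number<2>{}, 5);

    std::printf("%d %d %d\n",
                std::get<0>(srcs).origin,
                std::get<1>(srcs).origin,
                std::get<2>(srcs).origin); // prints: 0 15 25
}
```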
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp (+8, -24)

In `ThreadwiseTensorSliceTransfer_v7`, the slice-origin setters gain plural names, the indexed window movers take the compile-time index before the step, and the `MoveAllSrcSliceWindow`/`MoveAllDstSliceWindow` helpers are removed (callers now loop with `static_for` themselves).

@@ -80,8 +80,8 @@ struct ThreadwiseTensorSliceTransfer_v7

Before:

```cpp
    template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetSrcSliceOrigin(const SrcDescs& src_descs,
                                      const Indices& src_slice_origin_idxs)
    {
        static_for<0, nSrc, 1>{}([&](auto i) {
            src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
```

After:

```cpp
    template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
                                       const Indices& src_slice_origin_idxs)
    {
        static_for<0, nSrc, 1>{}([&](auto i) {
            src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
```

@@ -89,8 +89,8 @@ struct ThreadwiseTensorSliceTransfer_v7

Before:

```cpp
    template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetDstSliceOrigin(const DstDescs& dst_descs,
                                      const Indices& dst_slice_origin_idxs)
    {
        static_for<0, nDst, 1>{}([&](auto i) {
            dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
```

After:

```cpp
    template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
    __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
                                       const Indices& dst_slice_origin_idxs)
    {
        static_for<0, nDst, 1>{}([&](auto i) {
            dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
```

@@ -234,8 +234,8 @@ struct ThreadwiseTensorSliceTransfer_v7

Before:

```cpp
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <index_t ISrc>
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                       const Index& src_slice_origin_step_idx,
                                       Number<ISrc> iSrc)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
```

After:

```cpp
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <index_t ISrc>
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                       Number<ISrc> iSrc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
```

@@ -251,8 +251,8 @@ struct ThreadwiseTensorSliceTransfer_v7

Before:

```cpp
    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <index_t IDst>
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                       const Index& dst_slice_origin_step_idx,
                                       Number<IDst> iDst)
    {
        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
```

After:

```cpp
    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    template <index_t IDst>
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                       Number<IDst> iDst,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
```

@@ -265,22 +265,6 @@ struct ThreadwiseTensorSliceTransfer_v7

Removed (the all-window movers; their effect is now expressed at the call site by looping over the indexed movers with `static_for`):

```cpp
        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveAllSrcSliceWindow(const SrcDescs& src_descs,
                                          const Index& src_slice_origin_step_idx)
    {
        static_for<0, nSrc, 1>{}(
            [&](auto i) { MoveSrcSliceWindow(src_descs, src_slice_origin_step_idx, i); });
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveAllDstSliceWindow(const DstDescs& dst_descs,
                                          const Index& dst_slice_origin_step_idx)
    {
        static_for<0, nDst, 1>{}(
            [&](auto i) { MoveDstSliceWindow(dst_descs, dst_slice_origin_step_idx, i); });
    }

    private:
    SrcCoords src_coords_;
    DstCoords dst_coords_;
```
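The retained movers consult per-source flags (`SrcResetCoordinateAfterRunFlags::At(iSrc)`) to decide whether the step must be adjusted for a coordinate that was not reset after `Run()`. A hedged standalone analogue of a compile-time bool sequence with an `At` accessor (illustrative only; CK's `Sequence` type is more general than this sketch):

```cpp
#include <cstdio>
#include <type_traits>

// Minimal stand-in for Sequence<bool...>: a pack of flags with a
// compile-time indexed accessor.
template <bool... Flags>
struct BoolSequence
{
    template <std::size_t I>
    static constexpr bool At(std::integral_constant<std::size_t, I>)
    {
        constexpr bool flags[] = {Flags...};
        return flags[I];
    }
};

// Source 0's coordinate is reset after Run(); sources 1 and 2 are not,
// mirroring the Sequence<true, false, false> used in the gridwise GEMM.
using ResetFlags = BoolSequence<true, false, false>;

int main()
{
    // If a coordinate was reset, the raw step applies; otherwise the step
    // must absorb the backward offset left over from the previous access.
    constexpr int raw_step = 8, back_step = -3;
    constexpr auto i0 = std::integral_constant<std::size_t, 0>{};
    constexpr auto i1 = std::integral_constant<std::size_t, 1>{};

    constexpr int step0 = ResetFlags::At(i0) ? raw_step : raw_step + back_step;
    constexpr int step1 = ResetFlags::At(i1) ? raw_step : raw_step + back_step;

    std::printf("%d %d\n", step0, step1); // prints: 8 5
}
```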