Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e09f6e02
Commit
e09f6e02
authored
Jun 13, 2022
by
Chao Liu
Browse files
refactor
parent
57271814
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
24 deletions
+26
-24
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
...ation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
+8
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+13
-12
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
...ration/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
+5
-4
No files found.
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
View file @
e09f6e02
...
...
@@ -13,20 +13,20 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
typename
ThreadGroup
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
,
InMemoryDataOperationEnum
...
DstInMemOps
>
typename
ThreadTransferDstResetCoordinateAfterRunFlags
>
struct
ThreadGroupTensorSliceTransfer_v7
{
static
constexpr
index_t
nDim
=
...
...
@@ -147,13 +147,13 @@ struct ThreadGroupTensorSliceTransfer_v7
SrcDescs
,
DstDescs
,
ElementwiseOperation
,
DstInMemOps
,
decltype
(
thread_slice_lengths
),
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
,
DstInMemOps
...
>
;
ThreadTransferDstResetCoordinateAfterRunFlags
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
e09f6e02
...
...
@@ -548,26 +548,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
]);
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
// ThreadGroup
CDEElementwiseOperation
,
// ElementwiseOperation,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
ThisThreadBlock
,
// ThreadGroup
Tuple
<
FloatCShuffle
,
remove_cvref_t
<
tuple_element_t
<
0
,
DsDataType
>>
,
remove_cvref_t
<
tuple_element_t
<
1
,
DsDataType
>>>
,
Tuple
<
FloatE
>
,
// typename DstData,
decltype
(
c_ds_descs
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CDEElementwiseOperation
,
// ElementwiseOperation,
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitrary type
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
Sequence
<
true
,
false
,
false
>
,
// bool ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>
,
// bool ThreadTransferDstResetCoordinateAfterRunFlags
EGlobalMemoryDataOperation
>
// DstInMemOp,
Sequence
<
false
>>
// bool ThreadTransferDstResetCoordinateAfterRunFlags
{
c_ds_descs
,
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
View file @
e09f6e02
...
...
@@ -24,13 +24,13 @@ template <typename SrcDatas,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<...>
typename
DstResetCoordinateAfterRunFlags
,
// Sequence<...>
InMemoryDataOperationEnum
...
DstInMemOps
>
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
>
// Sequence<bool ...>
struct
ThreadwiseTensorSliceTransfer_v7
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -165,7 +165,8 @@ struct ThreadwiseTensorSliceTransfer_v7
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_descs
[
i
],
dst_coords_
[
i
]);
constexpr
auto
DstInMemOp
=
make_tuple
(
DstInMemOps
...)[
i
];
constexpr
InMemoryDataOperationEnum
DstInMemOp
=
static_cast
<
InMemoryDataOperationEnum
>
(
DstInMemOps
::
At
(
i
.
value
));
dst_bufs
(
i
).
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coords_
[
i
].
GetOffset
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment