gaoqiong / composable_kernel · Commits · f9b92b1e
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "d4cef21b646431da52eb5257ede499b5ae50b203"
Commit f9b92b1e, authored Jun 13, 2022 by Chao Liu

refactor

Parent: ff4f8ba8
Showing 5 changed files with 102 additions and 81 deletions (+102 −81):
+35 −29  include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
+8  −9   include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+50 −34  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
+1  −0   include/ck/utility/data_type.hpp
+8  −9   include/ck/utility/tuple.hpp
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp (view file @ f9b92b1e)
```diff
 #pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -34,6 +35,10 @@ template <typename ThreadGroup,
           bool ThreadTransferDstResetCoordinateAfterRun>
 struct ThreadGroupTensorSliceTransfer_v7
 {
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

     static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
@@ -106,11 +111,10 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
                                      tie(src0_buf, src1_buf, src2_buf),
                                      tie(dst_desc),
                                      tie(dst_buf));
         }
     }
@@ -119,7 +123,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(src0_desc, Src1Desc{}, Src2Desc{}), step, I0);
         }
     }
@@ -128,7 +133,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(Src0Desc{}, src1_desc, Src2Desc{}), step, I1);
         }
     }
@@ -137,7 +143,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(Src0Desc{}, Src1Desc{}, src2_desc), step, I2);
         }
     }
@@ -146,7 +153,7 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
-            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
+            threadwise_transfer_.MoveDstSliceWindow(tie(dst_desc), step, I0);
        }
    }
```
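The four hunks above collapse the per-operand movers (`MoveSrc0SliceWindow`, `MoveSrc1SliceWindow`, `MoveSrc2SliceWindow`, and the single-descriptor `MoveDstSliceWindow`) into one mover that takes a tuple of descriptors plus a compile-time index selecting which window to move. A minimal sketch of that pattern in plain C++, with `std::tuple`/`std::tie` standing in for CK's `Tuple`/`tie` and `std::integral_constant` for `Number<I>` (all names below are illustrative, not CK's API):

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>

template <std::size_t I>
using number = std::integral_constant<std::size_t, I>; // stand-in for CK's Number<I>

struct Desc { int origin = 0; }; // toy descriptor: just tracks a window origin

struct Transfer
{
    // One generic, compile-time-indexed mover replaces a family of
    // MoveSrc0/1/2SliceWindow methods: the index picks the tuple element.
    template <typename Descs, std::size_t I>
    void MoveSrcSliceWindow(const Descs& descs, int step, number<I>)
    {
        std::get<I>(descs).origin += step; // tuple elements are references, so this mutates
    }
};

int main()
{
    Transfer t;
    Desc d0, d1, d2;
    auto descs = std::tie(d0, d1, d2);           // tuple of references, like CK's tie()
    t.MoveSrcSliceWindow(descs, 4, number<1>{}); // moves only d1's window
    std::printf("%d %d %d\n", d0.origin, d1.origin, d2.origin); // prints: 0 4 0
}
```

In the diff, the unselected slots are filled with value-constructed placeholders (`Src1Desc{}`, `Src2Desc{}`) so every call site passes a full tuple even though only the indexed element is used.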
```diff
@@ -154,26 +161,25 @@ struct ThreadGroupTensorSliceTransfer_v7
     static constexpr auto thread_cluster_desc_ =
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
         Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
         Tuple<remove_cvref_t<DstData>>,
         Tuple<remove_reference_t<Src0Desc>&,
               remove_reference_t<Src1Desc>&,
               remove_reference_t<Src2Desc>&>,
         Tuple<remove_reference_t<DstDesc>&>,
         ElementwiseOperation,
         decltype(thread_slice_lengths),
         DimAccessOrder,
         VectorDim,
         ScalarPerVector,
         Sequence<ThreadTransferSrc0ResetCoordinateAfterRun,
                  ThreadTransferSrc1ResetCoordinateAfterRun,
                  ThreadTransferSrc2ResetCoordinateAfterRun>,
         Sequence<ThreadTransferDstResetCoordinateAfterRun>,
         DstInMemOp>;

     ThreadwiseTransfer threadwise_transfer_;
 };
 } // namespace ck
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp (view file @ f9b92b1e)
```diff
@@ -7,7 +7,6 @@
 #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "thread_group_tensor_slice_transfer_v6r3.hpp"
 #include "thread_group_tensor_slice_transfer_v7.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -124,7 +123,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         return generate_tuple(
             [&](auto i) {
-                using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>;
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                 return static_cast<const DDataType*>(nullptr);
             },
```
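The old form names the element type by materializing a temporary (`DsDataType{}.At(i)`) inside `decltype`; the new form uses the type-level query `tuple_element_t` instead. One plausible reason to prefer the query, sketched here with `std::tuple` (the `std::` names stand in for CK's own `Tuple` and the `tuple_element_t` added in `include/ck/utility/tuple.hpp` below): the `decltype` form is only well-formed when the tuple type is default-constructible.

```cpp
#include <tuple>
#include <type_traits>

struct NoDefault
{
    NoDefault() = delete; // element type without a default constructor
    explicit NoDefault(int) {}
};

using Ds = std::tuple<float, NoDefault>;

// Ill-formed even in the unevaluated decltype context, because overload
// resolution for Ds{} fails (the tuple's default constructor is unusable):
// using D1 = std::remove_cvref_t<decltype(std::get<1>(Ds{}))>;

// Fine: a purely type-level query; no object is ever constructed.
using D1 = std::tuple_element_t<1, Ds>;

static_assert(std::is_same_v<D1, NoDefault>);

int main() {}
```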
```diff
@@ -543,8 +542,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         ck::tensor_operation::element_wise::PassThrough{}};

     // shuffle: blockwise copy C from LDS to global
-#if 1
-    auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r3<
+#if 0
+    auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
         ThisThreadBlock,            // ThreadGroup
         CDEElementwiseOperation,    // ElementwiseOperation,
         EGlobalMemoryDataOperation, // DstInMemOp,
@@ -588,11 +587,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         1, 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
         FloatCShuffle,        // typename Src0Data,
-        remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
-        remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
+        remove_cvref_t<tuple_element_t<0, DsDataType>>, // typename Src1Data,
+        remove_cvref_t<tuple_element_t<1, DsDataType>>, // typename Src2Data,
         FloatE,               // typename DstData,
         decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
         decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
         decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
```
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp (view file @ f9b92b1e)
```diff
@@ -97,10 +97,10 @@ struct ThreadwiseTensorSliceTransfer_v7
         });
     }

     // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
     // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
     // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
     // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
     template <typename SrcBuffers,
               typename DstBuffers,
               enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
@@ -109,16 +109,18 @@ struct ThreadwiseTensorSliceTransfer_v7
     __device__ void Run(const SrcDescs& src_descs,
                         const SrcBuffers& src_bufs,
                         const DstDescs& dst_descs,
-                        const DstBuffers& dst_bufs)
+                        DstBuffers dst_bufs)
     {
         auto generate_vectors = [&](auto data_types) {
             constexpr index_t num = data_types.Size();

-            return generate_tuple([&](auto i) {
-                using DataType = remove_cvref_t<decltype(data_types[i])>;
-                return vector_type_maker_t<DataType, ScalarPerVector>{};
-            }, Number<num>{});
+            return generate_tuple(
+                [&](auto i) {
+                    using DataType = remove_cvref_t<decltype(data_types[i])>;
+                    return vector_type_maker_t<DataType, ScalarPerVector>{};
+                },
+                Number<num>{});
         };

         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
```
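`generate_vectors` builds, at compile time, a tuple holding one vector-typed staging buffer per source (or destination) data type. A rough standalone analogue, with `std::index_sequence` in place of CK's `generate_tuple`/`Number<num>` and a toy `vector_type` in place of `vector_type_maker_t` (illustrative stand-ins, not CK's types):

```cpp
#include <cstddef>
#include <tuple>
#include <utility>

template <typename T, int N>
struct vector_type { T data[N]; }; // toy stand-in for vector_type_maker_t

template <typename DataTypes, std::size_t... Is>
auto generate_vectors_impl(std::index_sequence<Is...>)
{
    // one staging vector per buffer, ScalarPerVector elements wide (4 here)
    return std::tuple<vector_type<std::tuple_element_t<Is, DataTypes>, 4>...>{};
}

template <typename DataTypes>
auto generate_vectors()
{
    return generate_vectors_impl<DataTypes>(
        std::make_index_sequence<std::tuple_size_v<DataTypes>>{});
}

int main()
{
    // tuple<vector_type<float, 4>, vector_type<int, 4>>
    auto src_vectors = generate_vectors<std::tuple<float, int>>();
    static_cast<void>(src_vectors);
}
```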
```diff
@@ -130,7 +132,7 @@ struct ThreadwiseTensorSliceTransfer_v7
         // copy data from src_bufs into src_vectors
         static_for<0, nSrc, 1>{}([&](auto i) {
-            using src_vector_t = remove_cvref_t<typename decltype(src_vectors[i])::type>;
+            using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;

             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
@@ -149,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v7
                 using DstData0 = remove_cvref_t<decltype(DstDatas{}[I0])>;

-                element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
+                element_op_(dst_vectors(I0).template AsType<DstData0>()(i),
                             src_vectors[I0].template AsType<SrcData0>()[i],
                             src_vectors[I1].template AsType<SrcData1>()[i],
                             src_vectors[I2].template AsType<SrcData2>()[i]);
@@ -157,7 +159,7 @@ struct ThreadwiseTensorSliceTransfer_v7
         // copy data from buf_vectors into dst_bufs
         static_for<0, nDst, 1>{}([&](auto i) {
-            using dst_vector_t = typename remove_cv_t<decltype(dst_vectors[i])>::type;
+            using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;

             const bool is_dst_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
```
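Both the `src_vector_t` and `dst_vector_t` hunks are about where cv/ref stripping happens relative to the `::type` member lookup. `remove_cv` leaves reference types untouched, so if `decltype(dst_vectors[i])` yields a reference, `remove_cv_t<...>::type` is ill-formed; stripping the reference as well (as `remove_cvref_t` does) makes the lookup valid. A standard-library sketch of the distinction (CK ships its own traits with the same semantics):

```cpp
#include <type_traits>

struct vec { using type = float; }; // toy vector wrapper exposing ::type

// remove_cv has no effect on a reference type...
static_assert(std::is_same_v<std::remove_cv_t<const vec&>, const vec&>);
// ...while remove_cvref (C++20) strips the reference and then the cv.
static_assert(std::is_same_v<std::remove_cvref_t<const vec&>, vec>);

// using bad = std::remove_cv_t<const vec&>::type; // ill-formed: a reference has no members
using ok = std::remove_cvref_t<const vec&>::type;  // OK: float
static_assert(std::is_same_v<ok, float>);

int main() {}
```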
```diff
@@ -230,39 +232,53 @@ struct ThreadwiseTensorSliceTransfer_v7
     }

     // src_slice_origin_step_idx need to be known at compile-time, for performance reason
+    template <index_t ISrc>
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
-                                       const Index& src_slice_origin_step_idx)
+                                       const Index& src_slice_origin_step_idx,
+                                       Number<ISrc> iSrc)
     {
-        static_for<0, nSrc, 1>{}([&](auto i) {
-            // if src coord was not reset by RunRead(), then need to adjust the step here
-            const auto adjusted_step_idx =
-                SrcResetCoordinateAfterRunFlags::At(i)
-                    ? src_slice_origin_step_idx
-                    : src_slice_origin_step_idx + GetCoordinateResetStep();
-
-            // is it OK to construct a new step every time?
-            const auto adjusted_step =
-                make_tensor_coordinate_step(src_descs[i], adjusted_step_idx);
-
-            move_tensor_coordinate(src_descs[i], src_coords_(i), adjusted_step);
-        });
+        // if src coord was not reset by RunRead(), then need to adjust the step here
+        const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
+                                           ? src_slice_origin_step_idx
+                                           : src_slice_origin_step_idx + GetCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
+
+        move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
     }

     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    template <index_t IDst>
     __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
-                                       const Index& dst_slice_origin_step_idx)
+                                       const Index& dst_slice_origin_step_idx,
+                                       Number<IDst> iDst)
     {
-        static_for<0, nDst, 1>{}([&](auto i) {
-            // if dst coord was not reset by Run(), then need to adjust the step here
-            const auto adjusted_step_idx =
-                DstResetCoordinateAfterRunFlags::At(i)
-                    ? dst_slice_origin_step_idx
-                    : dst_slice_origin_step_idx + GetCoordinateResetStep();
-
-            // is it OK to construct a new step every time?
-            const auto adjusted_step =
-                make_tensor_coordinate_step(dst_descs[i], adjusted_step_idx);
-
-            move_tensor_coordinate(dst_descs[i], dst_coords_(i), adjusted_step);
-        });
+        // if dst coord was not reset by Run(), then need to adjust the step here
+        const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
+                                           ? dst_slice_origin_step_idx
+                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
+
+        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
     }

+    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveAllSrcSliceWindow(const SrcDescs& src_descs,
+                                          const Index& src_slice_origin_step_idx)
+    {
+        static_for<0, nSrc, 1>{}(
+            [&](auto i) { MoveSrcSliceWindow(src_descs, src_slice_origin_step_idx, i); });
+    }
+
+    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveAllDstSliceWindow(const DstDescs& dst_descs,
+                                          const Index& dst_slice_origin_step_idx)
+    {
+        static_for<0, nDst, 1>{}(
+            [&](auto i) { MoveDstSliceWindow(dst_descs, dst_slice_origin_step_idx, i); });
+    }
+
     private:
```
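With the per-index movers in place, the new `MoveAllSrcSliceWindow`/`MoveAllDstSliceWindow` recover the old move-everything behavior by looping at compile time and delegating. A sketch of the same split in plain C++, with a fold over `std::index_sequence` standing in for CK's `static_for` (names are illustrative, not CK's API):

```cpp
#include <tuple>
#include <type_traits>
#include <utility>

template <std::size_t I>
using number = std::integral_constant<std::size_t, I>; // stand-in for Number<I>

struct Desc { int origin = 0; };

// per-index mover, as in the hunk above
template <typename Descs, std::size_t I>
void MoveSrcSliceWindow(const Descs& descs, int step, number<I>)
{
    std::get<I>(descs).origin += step;
}

// "move all" is just a compile-time loop over the per-index mover
template <typename Descs, std::size_t... Is>
void MoveAllImpl(const Descs& descs, int step, std::index_sequence<Is...>)
{
    (MoveSrcSliceWindow(descs, step, number<Is>{}), ...); // unrolled fold
}

template <typename... Ds>
void MoveAllSrcSliceWindow(const std::tuple<Ds&...>& descs, int step)
{
    MoveAllImpl(descs, step, std::index_sequence_for<Ds...>{});
}

int main()
{
    Desc d0, d1;
    auto descs = std::tie(d0, d1);
    MoveAllSrcSliceWindow(descs, 3); // both origins advance to 3
}
```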
include/ck/utility/data_type.hpp (view file @ f9b92b1e)
```diff
 #pragma once

 #include "statically_indexed_array.hpp"

 namespace ck {
```
include/ck/utility/tuple.hpp (view file @ f9b92b1e)
```diff
-#ifndef CK_TUPLE_HPP
-#define CK_TUPLE_HPP
+#pragma once

 #include "integral_constant.hpp"
 #include "sequence.hpp"
@@ -25,9 +24,9 @@ struct TupleElementKeyData
     __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
 #endif

     template <typename T,
               typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -36,7 +35,8 @@ struct TupleElementKeyData
 };

 template <typename Key, typename Data>
 __host__ __device__ constexpr const Data&
 get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
 {
     return static_cast<const Data&>(x.mData);
 }
@@ -179,13 +179,13 @@ struct Tuple<>
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
 };

 template <index_t I, typename TTuple>
 struct tuple_element
 {
     using type = decltype(TTuple{}.At(Number<I>{}));
 };

 template <index_t I, typename TTuple>
 using tuple_element_t = typename tuple_element<I, TTuple>::type;

 template <typename... Xs>
```
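The `tuple_element`/`tuple_element_t` pair added here mirrors `std::tuple_element`, but derives the element type from the tuple's own `At()` accessor. A toy version of the idiom; note that `TTuple{}` appearing inside `decltype` still requires the tuple to be default-constructible:

```cpp
#include <type_traits>

template <int I>
struct Number { static constexpr int value = I; };

struct MyTuple // toy tuple with a CK-style At() accessor
{
    int   a;
    float b;
    constexpr int   At(Number<0>) const { return a; }
    constexpr float At(Number<1>) const { return b; }
};

template <int I, typename TTuple>
struct tuple_element
{
    using type = decltype(TTuple{}.At(Number<I>{})); // type of the I-th element
};

template <int I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;

static_assert(std::is_same_v<tuple_element_t<0, MyTuple>, int>);
static_assert(std::is_same_v<tuple_element_t<1, MyTuple>, float>);

int main() {}
```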
```diff
@@ -202,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
 }

 } // namespace ck
-#endif
```