Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1ca98e75
Commit
1ca98e75
authored
Aug 26, 2024
by
aska-0096
Browse files
tempsave
parent
9a99c841
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
28 deletions
+27
-28
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+27
-28
No files found.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
1ca98e75
...
@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2
...
@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2
// 1. DstDesc is known at compile-time
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time
// 3. dst_slice_origin_idx is known at compile-time
template
<
typename
SrcData
,
template
<
typename
SrcData
s
,
typename
DstData
,
typename
DstData
s
,
typename
SrcDesc
,
typename
SrcDesc
s
,
typename
DstDesc
,
typename
DstDesc
s
,
typename
SliceLengths
,
typename
SliceLengths
,
typename
DimAccessOrder
,
typename
DimAccessOrder
,
index_t
SrcVectorDim
,
index_t
SrcVectorDim
,
index_t
SrcScalarPerVector
,
index_t
SrcScalarPerVector
s
,
index_t
SrcScalarStrideInVector
,
index_t
SrcScalarStrideInVector
s
,
bool
SrcResetCoordinateAfterRun
,
bool
SrcResetCoordinateAfterRun
,
bool
InvalidElementAsNaN
=
false
,
bool
InvalidElementAsNaN
=
false
,
typename
enable_if
<
DstDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
typename
enable_if
<
DstDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
ThreadwiseTensorSliceTransfer_v2r1
struct
ThreadwiseTensorSliceTransfer_v2r1
{
{
static_assert
((
InvalidElementAsNaN
&&
!
std
::
is_integral
<
DstData
>::
value
)
||
static_assert
((
InvalidElementAsNaN
&&
!
std
::
is_integral
<
DstData
s
>::
value
)
||
(
!
InvalidElementAsNaN
),
(
!
InvalidElementAsNaN
),
"Filling invalid element as NaN is only for floating point types"
);
"Filling invalid element as NaN is only for floating point types"
);
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nSrc
=
SrcDescs
::
Size
();
static
constexpr
index_t
nSrc
=
SrcDescs
::
Size
();
static
constexpr
index_t
n
Src
=
Src
Descs
::
Size
();
static
constexpr
index_t
n
Dst
=
Dst
Descs
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
using
Index
=
MultiIndex
<
nDim
>
;
...
@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1
...
@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v2
(
const
SrcDesc
&
src_desc
,
__device__
constexpr
ThreadwiseTensorSliceTransfer_v2
(
const
SrcDescs
&
src_descs
,
const
Index
&
src_slice_origin_idx
)
const
Indexs
&
src_slice_origin_idxs
)
:
src_coord_
(
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
))
{
{
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
"wrong! SrcDesc need to known at compile-time"
);
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
"wrong! Not divisible"
);
"wrong! Not divisible"
);
src_coords_
(
generate_tuple
([
&
](
auto
i
)
{
return
make_tensor_coordinate
(
src_desc
[
i
],
src_slice_origin_idx
[
i
]);
},
nSrc
);)
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
template
<
typename
SrcBuffers
,
typename
DstBuffers
,
typename
DstSliceOriginIdxs
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
,
const
DstSliceOriginIdxs
&
,
DstBuffers
&
dst_bufs
)
{
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
}
static_assert
(
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstDescs
>>::
IsKnownAtCompileTime
(),
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
DstSliceOriginIdx
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
DstBuffer
&
dst_buf
)
{
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! DstDesc need to known at compile-time"
);
"wrong! DstDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
DstSliceOriginIdx
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstSliceOriginIdx
s
>
>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
tuple_element_t
<
i
.
value
,
DstBuffer
>::
type
>
,
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstDatas
>>>::
value
&&
"wrong! inconsistent type"
);
"wrong! inconsistent type"
);
});
// DstDesc and dst_slice_origin_idx are known at compile-time
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cvref_t
<
DstDesc
>
{};
constexpr
auto
dst_desc
=
remove_cvref_t
<
DstDesc
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment