Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d73dd3e
"...resnet50_tensorflow.git" did not exist on "97c19f540b29da64048bbff1b4e7deb32024bb38"
Commit
5d73dd3e
authored
Sep 13, 2023
by
Jing Zhang
Browse files
add gridwise_multi_abd
parent
a66d14ed
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
2574 additions
and
63 deletions
+2574
-63
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
+72
-63
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
+205
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
.../tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
+60
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
...gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
+766
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+1109
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+362
-0
No files found.
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
View file @
5d73dd3e
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_
ab
d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
...
@@ -59,56 +59,65 @@ using BLayout = Col;
...
@@ -59,56 +59,65 @@ using BLayout = Col;
using
DLayout
=
Row
;
using
DLayout
=
Row
;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
struct
MultiATest
{
template
<
typename
A
,
typename
A0
,
typename
A1
>
__host__
__device__
constexpr
void
operator
()(
A
&
a
,
const
A0
&
a0
,
const
A1
&
a1
)
const
{
a
=
(
a0
+
a1
)
/
2
;
}
};
using
AElementOp
=
MultiATest
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
AlphaBetaAdd
;
using
CDEElementOp
=
AlphaBetaAdd
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
DeviceOpInstance
=
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleABD_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Xdl_CShuffle
<
ALayout
,
ck
::
Tuple
<
ALayout
,
ALayout
>
,
BLayout
,
ck
::
Tuple
<
BLayout
>
,
ck
::
Tuple
<
DLayout
>
,
ck
::
Tuple
<
DLayout
>
,
ELayout
,
ELayout
,
ADataType
,
ck
::
Tuple
<
ADataType
,
ADataType
>
,
BDataType
,
ck
::
Tuple
<
BDataType
>
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
EDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CDEElementOp
,
CDEElementOp
,
GemmSpec
,
GemmSpec
,
1
,
1
,
256
,
256
,
256
,
256
,
128
,
128
,
32
,
32
,
8
,
8
,
8
,
8
,
32
,
32
,
32
,
32
,
4
,
4
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -232,21 +241,21 @@ int main(int argc, char* argv[])
...
@@ -232,21 +241,21 @@ int main(int argc, char* argv[])
// do GEMM
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
auto
argument
=
device_op
.
MakeArgument
(
device_op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
a_device_buf
.
GetDeviceBuffer
(),
a_device_buf
.
GetDeviceBuffer
()
}
,
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
b_device_buf
.
GetDeviceBuffer
()
}
,
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
e_device_buf
.
GetDeviceBuffer
(),
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
std
::
array
<
ck
::
index_t
,
2
>
{
StrideA
,
StrideA
}
,
StrideB
,
std
::
array
<
ck
::
index_t
,
1
>
{
StrideB
}
,
std
::
array
<
ck
::
index_t
,
1
>
{
StrideD
},
std
::
array
<
ck
::
index_t
,
1
>
{
StrideD
},
StrideE
,
StrideE
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
cde_element_op
);
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -278,14 +287,14 @@ int main(int argc, char* argv[])
...
@@ -278,14 +287,14 @@ int main(int argc, char* argv[])
BDataType
,
BDataType
,
CShuffleDataType
,
CShuffleDataType
,
AccDataType
,
AccDataType
,
A
ElementOp
,
B
ElementOp
,
BElementOp
,
BElementOp
,
PassThrough
>
;
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a
_element_op
,
b_element_op
,
PassThrough
{});
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
b
_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
0 → 100644
View file @
5d73dd3e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp"
namespace
ck
{
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template
<
typename
ThreadGroup
,
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
>
struct
ThreadGroupTensorSliceTransfer_v7r2
{
static
constexpr
index_t
nDim
=
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
__device__
constexpr
ThreadGroupTensorSliceTransfer_v7r2
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_block_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_block_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
threadwise_transfer_
(
src_descs
,
StaticallyIndexedArray
<
Index
,
nSrc
>
{},
dst_descs
,
StaticallyIndexedArray
<
Index
,
nDst
>
{},
element_op
)
{
static_assert
(
nSrc
==
SrcDatas
::
Size
()
&&
nSrc
==
SrcDescs
::
Size
()
&&
nSrc
==
ThreadTransferSrcResetCoordinateAfterRunFlags
::
Size
()
&&
nDst
==
DstDatas
::
Size
()
&&
nDst
==
DstDescs
::
Size
()
&&
nDst
==
ThreadTransferDstResetCoordinateAfterRunFlags
::
Size
(),
"wrong!"
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
SrcDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_assert
(
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
SliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
const
auto
src_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
src_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nSrc
>
{});
const
auto
dst_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nDst
>
{});
threadwise_transfer_
.
SetSrcSliceOrigins
(
src_descs
,
src_thread_slice_origins
);
threadwise_transfer_
.
SetDstSliceOrigins
(
dst_descs
,
dst_thread_slice_origins
);
}
}
template
<
typename
SrcBuffers
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_descs
,
src_bufs
);
}
}
template
<
typename
DstBuffers
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
);
}
}
template
<
typename
SrcBuffers
,
typename
DstBuffers
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
RunRead
(
src_descs
,
src_bufs
);
RunWrite
(
dst_descs
,
dst_bufs
);
}
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_descs
,
iSrc
,
step
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
const
Index
&
step
)
{
MoveSrcSliceWindow
(
src_descs
,
Number
<
0
>
{},
step
);
}
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_descs
,
iDst
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
const
Index
&
step
)
{
MoveDstSliceWindow
(
dst_descs
,
Number
<
0
>
{},
step
);
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v7r2
<
SrcDatas
,
DstDatas
,
SrcDescs
,
DstDescs
,
ElementwiseOperation
,
DstInMemOps
,
decltype
(
thread_slice_lengths
),
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
0 → 100644
View file @
5d73dd3e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A0[M, K], B0[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleABD
:
public
BaseOperator
{
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
ck
::
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
5d73dd3e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
5d73dd3e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
0 → 100644
View file @
5d73dd3e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace
ck
{
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
//
// Does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
// 2. Pass tensor descritpors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinate when call Run()
template
<
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
>
// Sequence<bool ...>
struct
ThreadwiseTensorSliceTransfer_v7r2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nSrc
=
SrcDescs
::
Size
();
static
constexpr
index_t
nDst
=
DstDescs
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
// return a tuple of coordiantes for a tuple of tensor
template
<
typename
Descs
,
typename
Indices
,
enable_if_t
<
Descs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
static
constexpr
auto
MakeCoordinates
(
const
Descs
&
descs
,
const
Indices
&
indices
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
make_tensor_coordinate
(
descs
[
i
],
indices
[
i
]);
},
Number
<
Descs
::
Size
()
>
{});
}
using
SrcCoords
=
decltype
(
MakeCoordinates
(
SrcDescs
{},
StaticallyIndexedArray
<
Index
,
nSrc
>
{}));
using
DstCoords
=
decltype
(
MakeCoordinates
(
DstDescs
{},
StaticallyIndexedArray
<
Index
,
nDst
>
{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
using
SrcSpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
SrcDimAccessOrder
,
remove_cv_t
<
decltype
(
src_scalar_per_access
)
>>
;
static
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
using
DstSpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DstDimAccessOrder
,
remove_cv_t
<
decltype
(
dst_scalar_per_access
)
>>
;
__device__
constexpr
ThreadwiseTensorSliceTransfer_v7r2
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
src_coords_
(
MakeCoordinates
(
src_descs
,
src_slice_origins
)),
dst_coords_
(
MakeCoordinates
(
dst_descs
,
dst_slice_origins
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
static_assert
(
SliceLengths
::
At
(
Number
<
DstVectorDim
>
{})
%
DstScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
template
<
typename
Indices
,
enable_if_t
<
SrcDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetSrcSliceOrigins
(
const
SrcDescs
&
src_descs
,
const
Indices
&
src_slice_origin_idxs
)
{
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
src_coords_
(
i
)
=
make_tensor_coordinate
(
src_descs
[
i
],
src_slice_origin_idxs
[
i
]);
});
}
template
<
typename
Indices
,
enable_if_t
<
DstDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetDstSliceOrigins
(
const
DstDescs
&
dst_descs
,
const
Indices
&
dst_slice_origin_idxs
)
{
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
dst_coords_
(
i
)
=
make_tensor_coordinate
(
dst_descs
[
i
],
dst_slice_origin_idxs
[
i
]);
});
}
template
<
typename
DataTypes
,
index_t
ScalarPerVector
>
__device__
static
auto
generate_vectors
()
{
auto
data_types
=
DataTypes
{};
constexpr
index_t
num
=
data_types
.
Size
();
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
return
vector_type_maker_t
<
DataType
,
ScalarPerVector
>
{};
},
Number
<
num
>
{});
}
#if 1
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
enable_if_t
<
SrcDescs
::
Size
()
==
SrcBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
)
{
// loop over space-filling curve
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
src_vectors
=
generate_vectors
<
SrcDatas
,
SrcScalarPerVector
>
();
// copy data from src_bufs into src_vectors
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
using
src_vector_t
=
typename
remove_cvref_t
<
decltype
(
src_vectors
[
i
])
>::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_descs
[
i
],
src_coords_
[
i
]);
src_vectors_tuple_
(
iAccess
)(
i
).
template
AsType
<
src_vector_t
>()(
I0
)
=
src_bufs
[
i
].
template
Get
<
src_vector_t
>(
src_coords_
[
i
].
GetOffset
(),
is_src_valid
);
});
// move coordinate
if
constexpr
(
iAccess
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
SrcSpaceFillingCurve
::
GetForwardStep
(
iAccess
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
make_tensor_coordinate_step
(
src_descs
[
i
],
forward_step
));
});
}
});
// move coordinate back to slice origin (or not)
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
SrcResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_descs
[
i
],
GetSrcCoordinateResetStep
());
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
src_reset_step
);
}
});
}
#endif
#if 1
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
DstBuffers
,
enable_if_t
<
DstDescs
::
Size
()
==
DstBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
// loop over space-filling curve
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
src_vectors
=
src_vectors_tuple_
[
iAccess
];
auto
dst_vectors
=
generate_vectors
<
DstDatas
,
DstScalarPerVector
>
();
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
// copy data from buf_vectors into dst_bufs
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
using
dst_vector_t
=
typename
remove_cvref_t
<
decltype
(
dst_vectors
[
i
])
>::
type
;
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_descs
[
i
],
dst_coords_
[
i
]);
constexpr
InMemoryDataOperationEnum
DstInMemOp
=
static_cast
<
InMemoryDataOperationEnum
>
(
DstInMemOps
::
At
(
i
.
value
));
dst_bufs
(
i
).
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coords_
[
i
].
GetOffset
(),
is_dst_valid
,
dst_vectors
[
i
].
template
AsType
<
dst_vector_t
>()[
I0
]);
});
// move coordinate
if
constexpr
(
iAccess
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
DstSpaceFillingCurve
::
GetForwardStep
(
iAccess
);
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
make_tensor_coordinate_step
(
dst_descs
[
i
],
forward_step
));
});
}
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
DstResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_descs
[
i
],
GetDstCoordinateResetStep
());
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
dst_reset_step
);
}
});
}
#endif
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
SrcBuffers
,
typename
DstBuffers
,
enable_if_t
<
SrcDescs
::
Size
()
==
SrcBuffers
::
Size
()
&&
DstDescs
::
Size
()
==
DstBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
RunRead
(
src_descs
,
src_bufs
);
RunWrite
(
dst_descs
,
dst_bufs
);
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
{
if
constexpr
(
num_access
==
0
)
{
return
typename
SrcSpaceFillingCurve
::
Index
{};
}
else
{
return
SrcSpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
}
}
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
{
if
constexpr
(
num_access
==
0
)
{
return
typename
DstSpaceFillingCurve
::
Index
{};
}
else
{
return
DstSpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRunFlags
::
At
(
iSrc
)
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_descs
[
iSrc
],
adjusted_step_idx
);
move_tensor_coordinate
(
src_descs
[
iSrc
],
src_coords_
(
iSrc
),
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRunFlags
::
At
(
iDst
)
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_descs
[
iDst
],
adjusted_step_idx
);
move_tensor_coordinate
(
dst_descs
[
iDst
],
dst_coords_
(
iDst
),
adjusted_step
);
}
private:
using
SrcVectorsType
=
decltype
(
generate_vectors
<
SrcDatas
,
SrcScalarPerVector
>
());
using
DstVectorsType
=
decltype
(
generate_vectors
<
DstDatas
,
DstScalarPerVector
>
());
static
constexpr
auto
num_access
=
SrcSpaceFillingCurve
::
GetNumOfAccess
();
StaticallyIndexedArray
<
SrcVectorsType
,
num_access
>
src_vectors_tuple_
;
SrcCoords
src_coords_
;
DstCoords
dst_coords_
;
const
ElementwiseOperation
element_op_
;
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment