gaoqiong / composable_kernel / Commits / 760a234f

Commit 760a234f authored Dec 03, 2020 by Chao Liu
Parent: 70d06fa9

    use StaticallyIndexedArray for buffer in threadwise copy, in order to get rid of alloca in IR

Showing 6 changed files with 508 additions and 1395 deletions (+508, -1395)
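The commit message points at a common GPU code-generation problem: a plain C array that is indexed inside loops can be materialized by the compiler as an indexable stack object (an alloca, i.e. "scratch" memory), while a buffer whose every access uses a compile-time index can be kept as independent registers. The snippet below is a minimal, standalone C++17 sketch of that idea only; it is NOT composable_kernel's actual StaticallyIndexedArray, and the names repeat_t, static_for and At are invented here for illustration.

    // Sketch: a "statically indexed array" is a tuple of N scalars whose
    // accessor takes the index as a template argument, so no runtime-indexed
    // C array (and hence no alloca) ever has to exist.
    #include <cstddef>
    #include <cstdio>
    #include <tuple>
    #include <utility>

    template <std::size_t, typename T>
    using repeat_t = T; // maps every index to the same element type

    template <typename T, std::size_t N, typename = std::make_index_sequence<N>>
    struct StaticallyIndexedArraySketch;

    template <typename T, std::size_t N, std::size_t... Is>
    struct StaticallyIndexedArraySketch<T, N, std::index_sequence<Is...>>
    {
        std::tuple<repeat_t<Is, T>...> data_; // N named scalars, not "T data_[N]"

        template <std::size_t I>
        T& At(std::integral_constant<std::size_t, I>)
        {
            return std::get<I>(data_); // index is a template argument, never a runtime value
        }
    };

    // Compile-time-unrolled loop, so the indices fed to At() stay compile-time.
    template <typename F, std::size_t... Is>
    void static_for_impl(F&& f, std::index_sequence<Is...>)
    {
        (f(std::integral_constant<std::size_t, Is>{}), ...); // C++17 fold over comma
    }

    template <std::size_t N, typename F>
    void static_for(F&& f)
    {
        static_for_impl(std::forward<F>(f), std::make_index_sequence<N>{});
    }

    int main()
    {
        StaticallyIndexedArraySketch<float, 4> thread_buffer{};

        // fill and read back using compile-time indices only
        static_for<4>([&](auto i) { thread_buffer.At(i) = 1.0f + static_cast<float>(i); });

        float sum = 0.0f;
        static_for<4>([&](auto i) { sum += thread_buffer.At(i); });

        std::printf("sum = %f\n", sum); // 1 + 2 + 3 + 4 = 10
        return 0;
    }

Per the commit message, the thread buffer inside the threadwise copy is switched to such a statically indexed container so the AMDGPU backend no longer emits an alloca for it.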
Files changed:
    composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp   +35  -90
    composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp                      +62  -546
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                                        +69  -448
    composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp                     +332 -299
    driver/include/conv_common.hpp                                                                              +9   -11
    driver/src/conv_driver.cpp                                                                                  +1   -1
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

...
@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

         // GEMM
 #if 1
-        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
-            BlockSize, Float, AccFloat, InMemoryDataOperation::Set,
-            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
-            GemmMPerThread, GemmNPerThread, GemmKPerThread,
-            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-            Sequence<1, 0>, Sequence<1, 0>, 0,
-            GemmABlockTransferSrcScalarPerVector_GemmK,
-            GemmABlockTransferDstScalarPerVector_GemmM,
-            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>, Sequence<0, 1>, 1,
-            GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmN,
-            Sequence<2, 3, 0, 1>, 3,
-            GemmCThreadTransferDstScalarPerVector_GemmN1>;
+        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
+            BlockSize, Float, AccFloat, InMemoryDataOperation::Set,
+            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
+            GemmMPerThread, GemmNPerThread, GemmKPerThread,
+            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
+            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
+            Sequence<1, 0>, Sequence<1, 0>, 0,
+            GemmABlockTransferSrcScalarPerVector_GemmK,
+            GemmABlockTransferDstScalarPerVector_GemmM,
+            true, // move back src coordinate after threadwise copy
+            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
+            Sequence<0, 1>, Sequence<0, 1>, 1,
+            GemmBBlockTransferSrcScalarPerVector_GemmN,
+            GemmBBlockTransferDstScalarPerVector_GemmN,
+            false, // don't move back src coordinate after threadwise copy, which will be fused with
+                   // MoveSrcSliceWindow() to save addr computation
+            Sequence<2, 3, 0, 1>, 3,
+            GemmCThreadTransferDstScalarPerVector_GemmN1>;

         const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
...
@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                 p_out_global,
                 integral_constant<bool, false>{});
         }
-#else
-        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
-            BlockSize, Float, AccFloat, InMemoryDataOperation::Set,
-            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
-            GemmMPerThread, GemmNPerThread, GemmKPerThread,
-            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-            Sequence<1, 0>, Sequence<1, 0>, 0,
-            GemmABlockTransferSrcScalarPerVector_GemmK,
-            GemmABlockTransferDstScalarPerVector_GemmM,
-            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>, Sequence<0, 1>, 1,
-            GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmN,
-            Sequence<2, 3, 0, 1>, 3,
-            GemmCThreadTransferDstScalarPerVector_GemmN1>;
-
-        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
-
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   decltype(wei_gemmk_gemmm_global_desc),
-                                                   const Float*,
-                                                   decltype(in_gemmk_gemmn_global_desc),
-                                                   const Float*,
-                                                   decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
-                                                   Float*>;
-
-        launch_kernel(kernel,
-                      dim3(GridSize),
-                      dim3(BlockSize),
-                      0,
-                      0,
-                      wei_gemmk_gemmm_global_desc,
-                      p_wei_global,
-                      in_gemmk_gemmn_global_desc,
-                      p_in_global,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      p_out_global);
 #endif
     }
 };
...
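One of the two new boolean arguments in the GridwiseDynamicGemm_km_kn_mn_v1 instantiation above is commented "don't move back src coordinate after threadwise copy, which will be fused with MoveSrcSliceWindow() to save addr computation". The sketch below illustrates the arithmetic behind that fusion with plain integers. It is an illustration only (the drift and KPerBlock values are made up), not ck code: instead of stepping the coordinate back after each read and then advancing it by KPerBlock, the step-back is folded into the slice-window step, so the address update is paid once per iteration.

    #include <cassert>

    int main()
    {
        constexpr int KPerBlock = 8;
        constexpr int drift     = 3; // hypothetical drift the read leaves behind

        int off_reset = 0; // variant 1: reset coordinate after every read
        int off_fused = 0; // variant 2: no reset; window move uses a fused step

        constexpr int fused_step = -drift + KPerBlock; // "step back" + KPerBlock in one move

        for(int iter = 0; iter < 4; ++iter)
        {
            off_reset += drift;      // read advances the coordinate
            off_reset -= drift;      // explicit move-back (extra address math)
            off_reset += KPerBlock;  // MoveSrcSliceWindow

            off_fused += drift;      // read advances the coordinate
            off_fused += fused_step; // single fused MoveSrcSliceWindow

            assert(off_reset == off_fused); // both land on the same slice origins
        }
        return 0;
    }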
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp

...
@@ -9,514 +9,53 @@
 namespace ck {

-// this version does not have scratch memory issue, which is good, but I don't know why
-template <index_t BlockSize,
-          typename BlockSrcData, typename BlockDstData,
-          typename BlockSrcDesc, typename BlockDstDesc,
-          typename BlockSliceLengths, typename ThreadSliceLengths,
-          typename ThreadClusterLengths, typename ThreadClusterArrangeOrder,
-          typename SrcDstDimAccessOrder,
-          index_t SrcDstVectoReadDim, index_t SrcDataPerRead, index_t DstDataPerWrite,
-          AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
-          InMemoryDataOperation DstInMemOp,
-          index_t SrcDataStride, index_t DstDataStride>
-struct BlockwiseDynamicTensorSliceTransfer_v1r1
-{
-    static constexpr index_t nDim =
-        remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
-    using Index = MultiIndex<nDim>;
-
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v1r1(
-        const BlockSrcDesc& block_src_desc, const Index& src_block_slice_origin,
-        const BlockDstDesc& block_dst_desc, const Index& dst_block_slice_origin)
-        : threadwise_transfer_(block_src_desc, make_zero_multi_index<nDim>(),
-                               block_dst_desc, make_zero_multi_index<nDim>())
-    {
-        static_assert(nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
-
-        static_assert(is_same<BlockSliceLengths,
-                              decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
-                      "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), "wrong! BlockSize too small");
-
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_id =
-                thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
-
-            threadwise_transfer_.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
-            threadwise_transfer_.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
-        }
-    }
-
-    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.Run(p_block_src, p_block_dst);
-        }
-    }
-
-    __device__ void MoveSrcSliceWindow(const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.MoveSrcSliceWindow(step);
-        }
-    }
-
-    __device__ void MoveDstSliceWindow(const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.MoveDstSliceWindow(step);
-        }
-    }
-
-    private:
-    static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-    using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v1r1<BlockSrcDesc, BlockDstDesc, ThreadSliceLengths,
-                                                  SrcDstDimAccessOrder, SrcDstVectoReadDim,
-                                                  SrcDataPerRead, DstDataPerWrite,
-                                                  SrcAddressSpace, DstAddressSpace, DstInMemOp,
-                                                  SrcDataStride, DstDataStride>;
-
-    ThreadwiseTransfer threadwise_transfer_;
-};
-
-// this version tend to have scratch memory issue, due to:
-// 1. ThreadwiseDynamicTensorSliceTransfer_v1r1 keeps reference to tensor descriptor
-// 2. ThreadwiseDynamicTensorSliceTransfer_v1r1::Run() constructs new tensor coordinate
-template <index_t BlockSize,
-          typename BlockSrcData, typename BlockDstData,
-          typename BlockSrcDesc, typename BlockDstDesc,
-          typename BlockSliceLengths, typename ThreadSliceLengths,
-          typename ThreadClusterLengths, typename ThreadClusterArrangeOrder,
-          typename SrcDimAccessOrder, typename DstDimAccessOrder,
-          index_t SrcVectorReadDim, index_t DstVectorWriteDim,
-          index_t SrcDataPerRead, index_t DstDataPerWrite,
-          AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
-          InMemoryDataOperation DstInMemOp,
-          index_t SrcDataStride, index_t DstDataStride>
-struct BlockwiseDynamicTensorSliceTransfer_v2r1
-{
-    static constexpr index_t nDim =
-        remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
-    using Index = MultiIndex<nDim>;
-
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r1(
-        const BlockSrcDesc& block_src_desc, const Index& src_block_slice_origin,
-        const BlockDstDesc& block_dst_desc, const Index& dst_block_slice_origin)
-        : threadwise_read_(block_src_desc, make_zero_multi_index<nDim>(),
-                           thread_buffer_desc_, make_zero_multi_index<nDim>()),
-          threadwise_write_(thread_buffer_desc_, make_zero_multi_index<nDim>(),
-                            block_dst_desc, make_zero_multi_index<nDim>())
-    {
-        static_assert(nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
-
-        static_assert(is_same<BlockSliceLengths,
-                              decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
-                      "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), "wrong! BlockSize too small");
-
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_id =
-                thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
-
-            threadwise_read_.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
-            threadwise_read_.SetDstSliceOrigin(make_zero_multi_index<nDim>());
-
-            threadwise_write_.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
-            threadwise_write_.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
-        }
-    }
-
-    __device__ void RunRead(const BlockSrcData* p_block_src)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.Run(p_block_src, p_thread_buffer_);
-        }
-    }
-
-    __device__ void RunWrite(BlockDstData* p_block_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_write_.Run(p_thread_buffer_, p_block_dst);
-        }
-    }
-
-    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.Run(p_block_src, p_thread_buffer_);
-
-            // if there is type conversion, it's done during write
-            threadwise_write_.Run(p_thread_buffer_, p_block_dst);
-        }
-    }
-
-    __device__ void MoveSrcSliceWindow(const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.MoveSrcSliceWindow(step);
-        }
-    }
-
-    __device__ void MoveDstSliceWindow(const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_write_.MoveDstSliceWindow(step);
-        }
-    }
-
-    private:
-    static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-    static constexpr auto thread_buffer_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
-
-    using ThreadwiseRead =
-        ThreadwiseDynamicTensorSliceTransfer_v1r1<BlockSrcDesc, decltype(thread_buffer_desc_),
-                                                  ThreadSliceLengths, SrcDimAccessOrder,
-                                                  SrcVectorReadDim, SrcDataPerRead, 1,
-                                                  SrcAddressSpace, AddressSpace::Vgpr,
-                                                  InMemoryDataOperation::Set, SrcDataStride, 1>;
-
-    using ThreadwiseWrite =
-        ThreadwiseDynamicTensorSliceTransfer_v1r1<decltype(thread_buffer_desc_), BlockDstDesc,
-                                                  ThreadSliceLengths, DstDimAccessOrder,
-                                                  DstVectorWriteDim, 1, DstDataPerWrite,
-                                                  AddressSpace::Vgpr, DstAddressSpace, DstInMemOp,
-                                                  1, DstDataStride>;
-
-    ThreadwiseRead threadwise_read_;
-    ThreadwiseWrite threadwise_write_;
-
-    static constexpr index_t thread_buffer_element_size_ = thread_buffer_desc_.GetElementSpaceSize();
-
-    BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
-};
-
 // this version does following things to avoid scratch memory issue
-// 1. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
-// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2::Run() does not construct new tensor coordinate
+// 1. Use StaticallyIndexedArray instead of C array for thread buffer
+// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
+// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
-          typename BlockSrcData, typename BlockDstData,
-          typename BlockSrcDesc, typename BlockDstDesc,
-          typename BlockSliceLengths, typename ThreadSliceLengths,
-          typename ThreadClusterLengths, typename ThreadClusterArrangeOrder,
-          typename SrcDimAccessOrder, typename DstDimAccessOrder,
-          index_t SrcVectorReadDim, index_t DstVectorWriteDim,
-          index_t SrcDataPerRead, index_t DstDataPerWrite,
-          AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
-          InMemoryDataOperation DstInMemOp,
-          index_t SrcDataStride, index_t DstDataStride>
-struct BlockwiseDynamicTensorSliceTransfer_v2r2
-{
-    static constexpr index_t nDim =
-        remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
-    using Index = MultiIndex<nDim>;
-
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r2(
-        const BlockSrcDesc& block_src_desc, const Index& src_block_slice_origin,
-        const BlockDstDesc& block_dst_desc, const Index& dst_block_slice_origin)
-        : threadwise_read_(block_src_desc, make_zero_multi_index<nDim>(),
-                           thread_buffer_desc_, make_zero_multi_index<nDim>()),
-          threadwise_write_(thread_buffer_desc_, make_zero_multi_index<nDim>(),
-                            block_dst_desc, make_zero_multi_index<nDim>())
-    {
-        static_assert(nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
-
-        static_assert(is_same<BlockSliceLengths,
-                              decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
-                      "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), "wrong! BlockSize too small");
-
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_id =
-                thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
-
-            threadwise_read_.SetSrcSliceOrigin(block_src_desc,
-                                               src_block_slice_origin + thread_data_id_begin);
-            threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
-
-            threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
-            threadwise_write_.SetDstSliceOrigin(block_dst_desc,
-                                                dst_block_slice_origin + thread_data_id_begin);
-        }
-    }
-
-    __device__ void RunRead(const BlockSrcDesc& block_src_desc, const BlockSrcData* p_block_src)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer_);
-        }
-    }
-
-    __device__ void RunWrite(const BlockDstDesc& block_dst_desc, BlockDstData* p_block_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_write_.Run(thread_buffer_desc_, p_thread_buffer_, block_dst_desc, p_block_dst);
-        }
-    }
-
-    __device__ void Run(const BlockSrcDesc& block_src_desc, const BlockSrcData* p_block_src,
-                        const BlockDstDesc& block_dst_desc, BlockDstData* p_block_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer_);
-
-            // if there is type conversion, it's done during write
-            threadwise_write_.Run(thread_buffer_desc_, p_thread_buffer_, block_dst_desc, p_block_dst);
-        }
-    }
-
-    __device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
-        }
-    }
-
-    __device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
-        }
-    }
-
-    private:
-    static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-    static constexpr auto thread_buffer_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
-
-    using ThreadwiseRead =
-        ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc, decltype(thread_buffer_desc_),
-                                                  ThreadSliceLengths, SrcDimAccessOrder,
-                                                  SrcVectorReadDim, SrcDataPerRead, 1,
-                                                  SrcAddressSpace, AddressSpace::Vgpr,
-                                                  InMemoryDataOperation::Set, SrcDataStride, 1>;
-
-    using ThreadwiseWrite =
-        ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_), BlockDstDesc,
-                                                  ThreadSliceLengths, DstDimAccessOrder,
-                                                  DstVectorWriteDim, 1, DstDataPerWrite,
-                                                  AddressSpace::Vgpr, DstAddressSpace, DstInMemOp,
-                                                  1, DstDataStride>;
-
-    ThreadwiseRead threadwise_read_;
-    ThreadwiseWrite threadwise_write_;
-
-    static constexpr index_t thread_buffer_element_size_ = thread_buffer_desc_.GetElementSpaceSize();
-
-    BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
-};
-
-// this version does following things to avoid scratch memory issue
-// 1. BlockwiseDynamicTensorSliceTransfer_v2r3 doesn't allocate thread buffer (array) as member
-// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v1r2::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
-          typename BlockSrcData, typename BlockDstData,
-          typename BlockSrcDesc, typename BlockDstDesc,
-          typename BlockSliceLengths, typename ThreadSliceLengths,
-          typename ThreadClusterLengths, typename ThreadClusterArrangeOrder,
-          typename SrcDimAccessOrder, typename DstDimAccessOrder,
-          index_t SrcVectorReadDim, index_t DstVectorWriteDim,
-          index_t SrcDataPerRead, index_t DstDataPerWrite,
-          AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
-          InMemoryDataOperation DstInMemOp,
-          index_t SrcDataStride, index_t DstDataStride,
-          index_t ThreadTransferMoveBackSrcCoord = true,
-          index_t ThreadTransferMoveBackDstCoord = true>
-struct BlockwiseDynamicTensorSliceTransfer_v2r3
+template <index_t BlockSize,
+          InMemoryDataOperation DstInMemOp,
+          typename BlockSliceLengths, typename ThreadSliceLengths,
+          typename ThreadClusterLengths, typename ThreadClusterArrangeOrder,
+          typename SrcData, typename DstData,
+          typename SrcDesc, typename DstDesc,
+          typename SrcDimAccessOrder, typename DstDimAccessOrder,
+          index_t SrcVectorDim, index_t DstVectorDim,
+          index_t SrcScalarPerVector, index_t DstScalarPerVector,
+          AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
+          index_t SrcScalarStrideInVector, index_t DstScalarStrideInVector,
+          index_t ThreadTransferSrcResetCoordinateAfterRun,
+          index_t ThreadTransferDstResetCoordinateAfterRun>
+struct BlockwiseDynamicTensorSliceTransfer_v4
 {
-    static constexpr index_t nDim =
-        remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
+    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r3(
-        const BlockSrcDesc& block_src_desc, const Index& src_block_slice_origin,
-        const BlockDstDesc& block_dst_desc, const Index& dst_block_slice_origin)
-        : threadwise_read_(block_src_desc, make_zero_multi_index<nDim>(),
-                           thread_buffer_desc_, make_zero_multi_index<nDim>()),
-          threadwise_write_(thread_buffer_desc_, make_zero_multi_index<nDim>(),
-                            block_dst_desc, make_zero_multi_index<nDim>())
+    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(
+        const SrcDesc& src_desc, const Index& src_block_slice_origin,
+        const DstDesc& dst_desc, const Index& dst_block_slice_origin)
+        : threadwise_transfer_(src_desc, make_zero_multi_index<nDim>(),
+                               dst_desc, make_zero_multi_index<nDim>())
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
+        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
+                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
+                          nDim == ThreadClusterLengths::Size() &&
+                          nDim == ThreadClusterArrangeOrder::Size() &&
+                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
+                      "wrong! nDim not consistent");

         static_assert(is_same<BlockSliceLengths,
                               decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
...
@@ -533,13 +72,10 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r3
             const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};

-            threadwise_read_.SetSrcSliceOrigin(block_src_desc,
-                                               src_block_slice_origin + thread_data_id_begin);
-            threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
-
-            threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
-            threadwise_write_.SetDstSliceOrigin(block_dst_desc,
-                                                dst_block_slice_origin + thread_data_id_begin);
+            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
+                                                   src_block_slice_origin + thread_data_id_begin);
+            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
+                                                   dst_block_slice_origin + thread_data_id_begin);
         }
     }
...
@@ -551,86 +87,66 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r3
         return thread_cluster_id * ThreadSliceLengths{};
     }

-    __device__ void RunRead(const BlockSrcDesc& block_src_desc,
-                            const BlockSrcData* p_block_src,
-                            BlockSrcData* p_thread_buffer)
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer);
+            threadwise_transfer_.RunRead(src_desc, p_src);
         }
     }

-    __device__ void RunWrite(const BlockDstDesc& block_dst_desc,
-                             BlockDstData* p_block_dst,
-                             BlockDstData* p_thread_buffer)
+    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_write_.Run(thread_buffer_desc_, p_thread_buffer, block_dst_desc, p_block_dst);
+            threadwise_transfer_.RunWrite(dst_desc, p_dst);
         }
     }

-    __device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
+    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
         }
     }

-    __device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
+    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
+            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
     }

     static constexpr auto thread_cluster_desc_ =
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

-    static constexpr auto thread_buffer_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
-
-    using ThreadwiseRead =
-        ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc, decltype(thread_buffer_desc_),
-                                                  ThreadSliceLengths, SrcDimAccessOrder,
-                                                  SrcVectorReadDim, SrcDataPerRead, 1,
-                                                  SrcAddressSpace, AddressSpace::Vgpr,
-                                                  InMemoryDataOperation::Set, SrcDataStride, 1,
-                                                  ThreadTransferMoveBackSrcCoord, true>;
-
-    using ThreadwiseWrite =
-        ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_), BlockDstDesc,
-                                                  ThreadSliceLengths, DstDimAccessOrder,
-                                                  DstVectorWriteDim, 1, DstDataPerWrite,
-                                                  AddressSpace::Vgpr, DstAddressSpace, DstInMemOp,
-                                                  1, DstDataStride,
-                                                  true, ThreadTransferMoveBackDstCoord>;
-
-    ThreadwiseRead threadwise_read_;
-    ThreadwiseWrite threadwise_write_;
+    using ThreadwiseTransfer =
+        ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths, DstInMemOp,
+                                                SrcData, DstData, SrcDesc, DstDesc,
+                                                SrcDimAccessOrder, DstDimAccessOrder,
+                                                SrcVectorDim, DstVectorDim,
+                                                SrcScalarPerVector, DstScalarPerVector,
+                                                SrcScalarStrideInVector, DstScalarStrideInVector,
+                                                SrcAddressSpace, DstAddressSpace,
+                                                ThreadTransferSrcResetCoordinateAfterRun,
+                                                ThreadTransferDstResetCoordinateAfterRun>;
+
+    ThreadwiseTransfer threadwise_transfer_;
 };

 } // namespace ck
...
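Every method of both the removed and the new blockwise transfer classes is wrapped in the same participation guard: a thread runs its threadwise transfer only when the whole block maps onto the thread cluster, or when its 1-D thread id falls inside the cluster. The following is a small host-side sketch of that predicate only (the sizes are made up, and nothing here is ck code); it simply counts how many threads of a block would take part in the copy.

    #include <iostream>

    int main()
    {
        constexpr int BlockSize          = 256;
        constexpr int ClusterElementSize = 128; // stand-in for thread_cluster_desc_.GetElementSize()

        int active = 0;
        for(int tid = 0; tid < BlockSize; ++tid)
        {
            // same shape as the guard in RunRead/RunWrite/Move*SliceWindow
            if(BlockSize == ClusterElementSize || tid < ClusterElementSize)
            {
                ++active; // this thread would drive its threadwise_transfer_
            }
        }
        std::cout << active << " of " << BlockSize << " threads participate\n";
        return 0;
    }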
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

...
@@ -32,6 +32,7 @@ template <index_t BlockSize,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_M,
+          bool AThreadTransferSrcResetCoordinateAfterRun,
           typename BBlockTransferThreadSliceLengths_K_N,
           typename BBlockTransferThreadClusterLengths_K_N,
           typename BBlockTransferThreadClusterArrangeOrder,
...
@@ -39,6 +40,7 @@ template <index_t BlockSize,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_N,
+          bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
...
@@ -130,28 +132,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // A matrix blockwise copy
         auto a_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r3<
-                BlockSize, Float, Float,
-                decltype(a_k_m_global_desc), decltype(a_k_m_block_desc),
-                Sequence<KPerBlock, MPerBlock>,
-                ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
-                ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder,
-                Sequence<0, 1>, ABlockTransferSrcVectorDim, 1,
-                ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_M,
-                AddressSpace::Global, AddressSpace::Lds, InMemoryDataOperation::Set,
-                1, 1, true, true>(
+            BlockwiseDynamicTensorSliceTransfer_v4<
+                BlockSize, InMemoryDataOperation::Set,
+                Sequence<KPerBlock, MPerBlock>,
+                ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
+                ABlockTransferThreadClusterArrangeOrder,
+                Float, Float,
+                decltype(a_k_m_global_desc), decltype(a_k_m_block_desc),
+                ABlockTransferSrcAccessOrder, Sequence<0, 1>,
+                ABlockTransferSrcVectorDim, 1,
+                ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_M,
+                AddressSpace::Global, AddressSpace::Lds,
+                1, 1,
+                AThreadTransferSrcResetCoordinateAfterRun, true>(
                 a_k_m_global_desc,
                 make_multi_index(0, m_block_data_on_global),
                 a_k_m_block_desc,
...
@@ -159,32 +161,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // B matrix blockwise copy
         auto b_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r3<
-                BlockSize, Float, Float,
-                decltype(b_k_n_global_desc), decltype(b_k_n_block_desc),
-                Sequence<KPerBlock, NPerBlock>,
-                BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
-                BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder,
-                Sequence<0, 1>, BBlockTransferSrcVectorDim, 1,
-                BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_N,
-                AddressSpace::Global, AddressSpace::Lds, InMemoryDataOperation::Set,
-                1, 1,
-#if 0
-                true.
-#else
-                false,
-#endif
-                true>(
+            BlockwiseDynamicTensorSliceTransfer_v4<
+                BlockSize, InMemoryDataOperation::Set,
+                Sequence<KPerBlock, NPerBlock>,
+                BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
+                BBlockTransferThreadClusterArrangeOrder,
+                Float, Float,
+                decltype(b_k_n_global_desc), decltype(b_k_n_block_desc),
+                BBlockTransferSrcAccessOrder, Sequence<0, 1>,
+                BBlockTransferSrcVectorDim, 1,
+                BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_N,
+                AddressSpace::Global, AddressSpace::Lds,
+                1, 1,
+                BThreadTransferSrcResetCoordinateAfterRun, true>(
                 b_k_n_global_desc,
                 make_multi_index(0, n_block_data_on_global),
                 b_k_n_block_desc,
...
@@ -253,25 +251,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-#if 0
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-#else
-        // HACK: fuse threadwise copy move-back coordinate with move src slice window
-        constexpr auto b_block_slice_copy_step =
-            b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
-#endif

         // LDS double buffer: preload data into LDS
         {
-            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
+
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
         }

         // LDS double buffer: main body
...
@@ -298,19 +286,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
             __syncthreads();

-            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
             // LDS doubel buffer: load next data from device mem
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global);

             // LDS double buffer: GEMM on current data
             block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

             // LDS double buffer: store next data to LDS
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
         }
     }
...
@@ -323,21 +308,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
             __syncthreads();

-            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
             // LDS double buffer: load last data from device mem
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global);

             // LDS double buffer: GEMM on 2nd-last data
             block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

             // LDS double buffer: store last data to LDS
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);

             __syncthreads();
...
@@ -378,6 +358,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
             const index_t n_thread_data_on_global =
                 n_block_data_on_global + c_thread_mtx_on_block.col;

             ThreadwiseDynamicTensorSliceTransfer_v1r2<
+                AccFloat,
+                Float,
                 decltype(c_m0_m1_n0_n1_thread_desc),
                 decltype(c_m0_m1_n0_n1_global_desc),
                 Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
...
@@ -389,13 +371,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                 AddressSpace::Global,
                 CGlobalMemoryDataOperation,
                 1,
-                1>(
+                1,
+                true,
+                true>(
                 c_m0_m1_n0_n1_thread_desc,
                 make_multi_index(0, 0, 0, 0),
                 c_m0_m1_n0_n1_global_desc,
                 make_multi_index(m_thread_data_on_global / M1,
                                  m_thread_data_on_global % M1,
                                  n_thread_data_on_global / N1,
                                  n_thread_data_on_global % N1))
                 .Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
         }
     }
...
@@ -423,368 +407,5 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                 integral_constant<bool, IsEvenNumberKBlockLoop>{});
     }
 };

-template <index_t BlockSize,
-          typename Float, typename AccFloat,
-          InMemoryDataOperation CGlobalMemoryDataOperation,
-          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
-          index_t MPerThread, index_t NPerThread, index_t KPerThread,
-          index_t MLevel0Cluster, index_t NLevel0Cluster,
-          index_t MLevel1Cluster, index_t NLevel1Cluster,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
-          typename ABlockTransferThreadClusterArrangeOrder,
-          typename ABlockTransferSrcAccessOrder,
-          index_t ABlockTransferSrcVectorDim,
-          index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
-          typename BBlockTransferThreadClusterArrangeOrder,
-          typename BBlockTransferSrcAccessOrder,
-          index_t BBlockTransferSrcVectorDim,
-          index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
-          typename CThreadTransferSrcDstAccessOrder,
-          index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector>
-struct GridwiseDynamicGemm_km_kn_mn_v2
-{
-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
-    {
-        constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
-                                                    BBlockTransferDstScalarPerVector_N,
-                                                    MPerThread, NPerThread);
-
-        // A matrix in LDS memory, dst of blockwise copy
-        //   be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
-            make_multi_index(KPerBlock, MPerBlock), max_lds_align);
-
-        // B matrix in LDS memory, dst of blockwise copy
-        //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
-            make_multi_index(KPerBlock, NPerBlock), max_lds_align);
-
-        // LDS allocation for A and B: be careful of alignment
-        constexpr index_t a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
-        constexpr index_t b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
-
-        return (a_block_space_size + b_block_space_size) * sizeof(Float);
-    }
-
-    template <typename... ADesc, typename... BDesc, typename... CDesc>
-    __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
-                        const Float* __restrict__ p_a_global,
-                        const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
-                        const Float* __restrict__ p_b_global,
-                        const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
-                        Float* __restrict__ p_c_global,
-                        Float* __restrict__ p_shared_block) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-
-        const index_t K = a_k_m_global_desc.GetLength(I0);
-        const index_t M = a_k_m_global_desc.GetLength(I1);
-        const index_t N = b_k_n_global_desc.GetLength(I1);
-
-        // divide block work by [M, N]
-#if 0
-        const index_t m_block_work_num = M / MPerBlock;
-        const index_t n_block_work_num = N / NPerBlock;
-#else
-        // Hack: this force result into SGPR
-        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
-        const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
-#endif
-
-#if 0
-        const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
-        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
-#else
-        // Hack: this force result into SGPR
-        const index_t m_block_work_id =
-            __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
-        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
-#endif
-
-        const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
-        const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
-
-        // lds max alignment
-        constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
-                                                    BBlockTransferDstScalarPerVector_N,
-                                                    MPerThread, NPerThread);
-
-        // A matrix in LDS memory, dst of blockwise copy
-        //   be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
-            make_multi_index(KPerBlock, MPerBlock), max_lds_align);
-
-        // B matrix in LDS memory, dst of blockwise copy
-        //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
-            make_multi_index(KPerBlock, NPerBlock), max_lds_align);
-
-        // A matrix blockwise copy
-        auto a_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r3<
-                BlockSize, Float, Float,
-                decltype(a_k_m_global_desc), decltype(a_k_m_block_desc),
-                Sequence<KPerBlock, MPerBlock>,
-                ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
-                ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder,
-                Sequence<0, 1>, ABlockTransferSrcVectorDim, 1,
-                ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_M,
-                AddressSpace::Global, AddressSpace::Lds, InMemoryDataOperation::Set,
-                1, 1, true, true>(
-                a_k_m_global_desc, make_multi_index(0, m_block_data_on_global),
-                a_k_m_block_desc, make_multi_index(0, 0));
-
-        // B matrix blockwise copy
-        auto b_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r3<
-                BlockSize, Float, Float,
-                decltype(b_k_n_global_desc), decltype(b_k_n_block_desc),
-                Sequence<KPerBlock, NPerBlock>,
-                BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
-                BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder,
-                Sequence<0, 1>, BBlockTransferSrcVectorDim, 1,
-                BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_N,
-                AddressSpace::Global, AddressSpace::Lds, InMemoryDataOperation::Set,
-                1, 1,
-#if 0
-                true.
-#else
-                false,
-#endif
-                true>(
-                b_k_n_global_desc, make_multi_index(0, n_block_data_on_global),
-                b_k_n_block_desc, make_multi_index(0, 0));
-
-        // GEMM definition
-        //   c_mtx += transpose(a_mtx) * b_mtx
-        //     a_mtx[KPerBlock, MPerBlock] is in LDS
-        //     b_mtx[KPerBlocl, NPerBlock] is in LDS
-        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
-        //     register
-        constexpr index_t a_k_m_block_mtx_stride =
-            a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
-            a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
-        constexpr index_t b_k_n_block_mtx_stride =
-            b_k_n_block_desc.CalculateOffset(make_multi_index(1, 0)) -
-            b_k_n_block_desc.CalculateOffset(make_multi_index(0, 0));
-
-        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
-            Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
-        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
-            Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
-
-        // sanity check
-        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
-                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
-                      "wrong!");
-
-        constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
-
-        // c_thread_mtx definition: this is a mess
-        // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
-
-        const auto block_gemm =
-            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
-                BlockSize,
-                decltype(a_k_m_block_mtx_desc), decltype(b_k_n_block_mtx_desc),
-                decltype(c_m0m1_n0n1_thread_mtx_desc),
-                MPerThread, NPerThread, KPerThread,
-                MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
-                MPerThread, NPerThread>{};
-
-        // LDS allocation for A and B: be careful of alignment
-        constexpr index_t a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
-        constexpr index_t b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
-
-        Float* p_a_block = p_shared_block;
-        Float* p_b_block = p_shared_block + a_block_space_size;
-
-        // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
-
-        // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
-
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-
-#if 0
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-#else
-        // HACK: fuse threadwise copy move-back coordinate with move src slice window
-        constexpr auto b_block_slice_copy_step =
-            b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
-#endif
-
-        // preload data into LDS
-        {
-            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
-        }
-
-        // main body
-        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
-            k_block_data_begin += KPerBlock)
-        {
-            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-            a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-            b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
-
-            // load next data from device mem
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-            __syncthreads();
-
-            // GEMM on current data
-            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
-
-            __syncthreads();
-
-            // store next data to LDS
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
-        }
-
-        // tail
-        {
-            __syncthreads();
-
-            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
-        }
-
-        // output: register to global memory
-        {
-            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
-            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
-
-            // define input tensor descriptor for threadwise copy
-            //   thread input tensor, src of threadwise copy
-            constexpr auto c_m0_m1_n0_n1_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed<4>(
-                    make_multi_index(MRepeat, MPerThread, NRepeat, NPerThread));
-
-            // calculate origin of thread input tensor on global memory
-            //   blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-            const index_t m_thread_data_on_global =
-                m_block_data_on_global + c_thread_mtx_on_block.row;
-            const index_t n_thread_data_on_global =
-                n_block_data_on_global + c_thread_mtx_on_block.col;
-
-            ThreadwiseDynamicTensorSliceTransfer_v1r2<
-                decltype(c_m0_m1_n0_n1_thread_desc), decltype(c_m0_m1_n0_n1_global_desc),
-                Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
-                CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, 1,
-                CThreadTransferDstScalarPerVector,
-                AddressSpace::Vgpr, AddressSpace::Global, CGlobalMemoryDataOperation,
-                1, 1>(
-                c_m0_m1_n0_n1_thread_desc, make_multi_index(0, 0, 0, 0),
-                c_m0_m1_n0_n1_global_desc,
-                make_multi_index(m_thread_data_on_global / M1,
-                                 m_thread_data_on_global % M1,
-                                 n_thread_data_on_global / N1,
-                                 n_thread_data_on_global % N1))
-                .Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
-        }
-    }
-
-    template <typename... ADesc, typename... BDesc, typename... CDesc>
-    __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
-                        const Float* __restrict__ p_a_global,
-                        const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
-                        const Float* __restrict__ p_b_global,
-                        const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
-                        Float* __restrict__ p_c_global) const
-    {
-        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
-
-        __shared__ Float p_shared_block[shared_block_size];
-
-        Run(a_k_m_global_desc, p_a_global, b_k_n_global_desc, p_b_global,
-            c_m0_m1_n0_n1_global_desc, p_c_global, p_shared_block);
-    }
-};

 } // namespace ck
 #endif
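The removed GridwiseDynamicGemm_km_kn_mn_v2 above carries the comment "Hack: this force result into SGPR": block-uniform quantities such as M / MPerBlock are passed through __builtin_amdgcn_readfirstlane, which reads the value from lane 0 and therefore lets the compiler keep it in a scalar register instead of a per-lane vector register. The snippet below is a hedged HIP sketch of that pattern only; the kernel and its names are illustrative and are not part of composable_kernel, and it requires an AMD GPU target to build.

    #include <hip/hip_runtime.h>

    __global__ void block_work_id_demo(int M, int MPerBlock, int* out)
    {
        // Without the hack the compiler may keep the quotient in a per-lane VGPR:
        //   int m_block_work_num = M / MPerBlock;

        // With the hack the value is declared wave-uniform, so it can live in an SGPR.
        int m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);

        if(threadIdx.x == 0)
            out[blockIdx.x] = m_block_work_num;
    }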
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
760a234f
...
...
@@ -7,185 +7,12 @@
namespace
ck
{
// this version tends to have scratch memory issue, due to:
// 1. It keeps reference to tensor descriptor
// 2. It constructs new tensor coordinate in this->Run()
template
<
typename
SrcDesc
,
typename
DstDesc
,
typename
SliceLengths
,
typename
SrcDstDimAccessOrder
,
index_t
SrcDstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
AddressSpace
SrcAddressSpace
,
AddressSpace
DstAddressSpace
,
InMemoryDataOperation
DstInMemOp
,
index_t
SrcScalarStrideInVector
,
index_t
DstScalarStrideInVector
>
struct
ThreadwiseDynamicTensorSliceTransfer_v1r1
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
using
SrcCoord
=
decltype
(
make_dynamic_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_dynamic_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
SrcCoordStep
=
decltype
(
make_dynamic_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_dynamic_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseDynamicTensorSliceTransfer_v1r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
)
:
src_desc_
(
src_desc
),
src_slice_origin_
(
make_dynamic_tensor_coordinate
(
src_desc
,
src_slice_origin
)),
dst_desc_
(
dst_desc
),
dst_slice_origin_
(
make_dynamic_tensor_coordinate
(
dst_desc
,
dst_slice_origin
))
{
}
__device__
constexpr
ThreadwiseDynamicTensorSliceTransfer_v1r1
()
:
ThreadwiseDynamicTensorSliceTransfer_v1r1
(
SrcDesc
{},
make_zero_multi_index
<
nDim
>
(),
DstDesc
{},
make_zero_multi_index
<
nDim
>
())
{
}
template
<
typename
SrcData
,
typename
DstData
>
__device__
void
Run
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
// comment: construction tensor coordinate here tends to cause scratch memory issue
auto
src_coord
=
src_slice_origin_
;
auto
dst_coord
=
dst_slice_origin_
;
// TODO use constexpr for coordinate-step to make sure compiler behave correctly
const
auto
src_step_0_p1
=
make_dynamic_tensor_coordinate_step
(
src_desc_
,
make_multi_index
(
0
,
1
));
const
auto
src_step_0_m1
=
make_dynamic_tensor_coordinate_step
(
src_desc_
,
make_multi_index
(
0
,
-
1
));
const
auto
src_step_p1_0
=
make_dynamic_tensor_coordinate_step
(
src_desc_
,
make_multi_index
(
1
,
0
));
const
auto
src_step_m1_0
=
make_dynamic_tensor_coordinate_step
(
src_desc_
,
make_multi_index
(
-
1
,
0
));
const
auto
dst_step_0_p1
=
make_dynamic_tensor_coordinate_step
(
dst_desc_
,
make_multi_index
(
0
,
1
));
const
auto
dst_step_0_m1
=
make_dynamic_tensor_coordinate_step
(
dst_desc_
,
make_multi_index
(
0
,
-
1
));
const
auto
dst_step_p1_0
=
make_dynamic_tensor_coordinate_step
(
dst_desc_
,
make_multi_index
(
1
,
0
));
const
auto
dst_step_m1_0
=
make_dynamic_tensor_coordinate_step
(
dst_desc_
,
make_multi_index
(
-
1
,
0
));
constexpr
index_t
Len0
=
SliceLengths
{}[
0
];
constexpr
index_t
Len1
=
SliceLengths
{}[
1
];
bool
forward_dim0
=
true
;
bool
forward_dim1
=
true
;
// hardcoded for 2d loop for now
#pragma unroll
for
(
index_t
i0
=
0
;
i0
<
Len0
;
++
i0
)
{
#pragma unroll
for
(
index_t
i1
=
0
;
i1
<
Len1
;
++
i1
)
{
// do work
transfer_data
<
SrcData
,
1
,
SrcAddressSpace
,
DstAddressSpace
,
DstInMemOp
,
SrcScalarStrideInVector
,
DstScalarStrideInVector
>
(
p_src
,
src_coord
.
GetOffset
(),
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc_
,
src_coord
),
src_desc_
.
GetElementSpaceSize
(),
p_dst
,
dst_coord
.
GetOffset
(),
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc_
,
dst_coord
),
dst_desc_
.
GetElementSpaceSize
());
// move dim1 iterator
if
(
i1
<
Len1
-
1
)
{
if
(
forward_dim1
)
{
move_dynamic_tensor_coordinate
(
src_desc_
,
src_coord
,
src_step_0_p1
);
move_dynamic_tensor_coordinate
(
dst_desc_
,
dst_coord
,
dst_step_0_p1
);
}
else
{
move_dynamic_tensor_coordinate
(
src_desc_
,
src_coord
,
src_step_0_m1
);
move_dynamic_tensor_coordinate
(
dst_desc_
,
dst_coord
,
dst_step_0_m1
);
}
}
}
// switch dim1 iteration direction
forward_dim1
=
!
forward_dim1
;
// move dim0 iterator
if
(
i0
<
Len0
-
1
)
{
if
(
forward_dim0
)
{
move_dynamic_tensor_coordinate
(
src_desc_
,
src_coord
,
src_step_p1_0
);
move_dynamic_tensor_coordinate
(
dst_desc_
,
dst_coord
,
dst_step_p1_0
);
}
else
{
move_dynamic_tensor_coordinate
(
src_desc_
,
src_coord
,
src_step_m1_0
);
move_dynamic_tensor_coordinate
(
dst_desc_
,
dst_coord
,
dst_step_m1_0
);
}
}
}
}
__device__
void
SetSrcSliceOrigin
(
const
Index
&
src_slice_origin_idx
)
{
src_slice_origin_
=
make_dynamic_tensor_coordinate
(
src_desc_
,
src_slice_origin_idx
);
}
__device__
void
SetDstSliceOrigin
(
const
Index
&
dst_slice_origin_idx
)
{
dst_slice_origin_
=
make_dynamic_tensor_coordinate
(
dst_desc_
,
dst_slice_origin_idx
);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveSrcSliceWindow
(
const
Index
&
src_slice_origin_step_idx
)
{
// is it OK to construct a new step every time?
const
auto
src_slice_origin_step
=
make_dynamic_tensor_coordinate_step
(
src_desc_
,
src_slice_origin_step_idx
);
move_dynamic_tensor_coordinate
(
src_desc_
,
src_slice_origin_
,
src_slice_origin_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveDstSliceWindow
(
const
Index
&
dst_slice_origin_step_idx
)
{
// is it OK to construct a new step every time?
const
auto
dst_slice_origin_step
=
make_dynamic_tensor_coordinate_step
(
dst_desc_
,
dst_slice_origin_step_idx
);
move_dynamic_tensor_coordinate
(
dst_desc_
,
dst_slice_origin_
,
dst_slice_origin_step
);
}
private:
const
SrcDesc
&
src_desc_
;
const
DstDesc
&
dst_desc_
;
SrcCoord
src_slice_origin_
;
DstCoord
dst_slice_origin_
;
};
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
template
<
typename
SrcDesc
,
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
SliceLengths
,
typename
SrcDstDimAccessOrder
,
...
...
@@ -197,8 +24,12 @@ template <typename SrcDesc,
InMemoryDataOperation
DstInMemOp
,
index_t
SrcScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
MoveBackSrcCoord
=
true
,
bool
MoveBackDstCoord
=
true
>
bool
SrcResetCoordinateAfterRun
,
// control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool
DstResetCoordinateAfterRun
>
// control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct
ThreadwiseDynamicTensorSliceTransfer_v1r2
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
...
...
@@ -225,7 +56,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
{
}
template
<
typename
SrcData
,
typename
DstData
>
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_slice_origin_
=
make_dynamic_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
)
{
dst_slice_origin_
=
make_dynamic_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
const
DstDesc
&
dst_desc
,
DstData
*
p_dst
)
{
...
...
@@ -256,13 +96,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
        constexpr index_t Len1 = SliceLengths{}[1];

#pragma unroll
-       for(index_t i0 = 0; i0 < Len0; ++i0)
+       for(index_t iter0 = 0; iter0 < Len0; ++iter0)
        {
#pragma unroll
-           for(index_t i1 = 0; i1 < Len1; ++i1)
+           for(index_t iter1 = 0; iter1 < Len1; ++iter1)
            {
#if 1 // debug
                // do work
                transfer_data<SrcData,
                              1,
                              SrcAddressSpace,
...
...
@@ -280,68 +119,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                              coordinate_has_valid_offset_assuming_visible_index_is_valid(
                                  dst_desc, dst_slice_origin_),
                              dst_desc.GetElementSpaceSize());
#else
                if constexpr(SrcAddressSpace == AddressSpace::Global &&
                             DstAddressSpace == AddressSpace::Vgpr)
                {
                    if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
                           dst_desc, dst_slice_origin_))
                    {
                        const SrcData tmp = amd_buffer_load<SrcData, 1>(
                            p_src,
                            src_slice_origin_.GetOffset(),
                            coordinate_has_valid_offset_assuming_visible_index_is_valid(
                                src_desc, src_slice_origin_),
                            src_desc.GetElementSpaceSize());

                        const index_t dst_offset = dst_slice_origin_.GetOffset();

                        p_dst[dst_offset] = tmp;
                    }
                }
                else if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
                                  DstAddressSpace == AddressSpace::Global)
                {
                    const SrcData zeros = 0;

                    const bool src_valid =
                        coordinate_has_valid_offset_assuming_visible_index_is_valid(
                            src_desc, src_slice_origin_);
                    const bool dst_valid =
                        coordinate_has_valid_offset_assuming_visible_index_is_valid(
                            dst_desc, dst_slice_origin_);

                    amd_buffer_store<SrcData, 1>(
                        src_valid ? &(p_src[src_slice_origin_.GetOffset()]) : &zeros,
                        p_dst,
                        dst_slice_origin_.GetOffset(),
                        dst_valid,
                        dst_desc.GetElementSpaceSize());
                }
                else
                {
                    if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
                           dst_desc, dst_slice_origin_))
                    {
                        if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
                               src_desc, src_slice_origin_))
                        {
                            p_dst[dst_slice_origin_.GetOffset()] =
                                p_src[src_slice_origin_.GetOffset()];
                        }
                        else
                        {
                            p_dst[dst_slice_origin_.GetOffset()] = 0;
                        }
                    }
                }
#endif

                // move dim1 iterator
-               if(i1 < Len1 - 1)
+               if(iter1 < Len1 - 1)
                {
-                   bool forward_dim1 = (i0 % 2 == 0);
+                   bool forward_dim1 = (iter0 % 2 == 0);

                    if(forward_dim1)
                    {
...
...
@@ -361,7 +143,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
            }

            // move dim0 iterator
-           if(i0 < Len0 - 1)
+           if(iter0 < Len0 - 1)
            {
                move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
                move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
...
...
@@ -416,22 +198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
        constexpr index_t Len2 = SliceLengths{}[2];
        constexpr index_t Len3 = SliceLengths{}[3];

-       bool forward_dim0 = true;
-       bool forward_dim1 = true;
-       bool forward_dim2 = true;
-       bool forward_dim3 = true;

#pragma unroll
-       for(index_t i0 = 0; i0 < Len0; ++i0)
+       for(index_t iter0 = 0; iter0 < Len0; ++iter0)
        {
#pragma unroll
-           for(index_t i1 = 0; i1 < Len1; ++i1)
+           for(index_t iter1 = 0; iter1 < Len1; ++iter1)
            {
#pragma unroll
-               for(index_t i2 = 0; i2 < Len2; ++i2)
+               for(index_t iter2 = 0; iter2 < Len2; ++iter2)
                {
#pragma unroll
-                   for(index_t i3 = 0; i3 < Len3; ++i3)
+                   for(index_t iter3 = 0; iter3 < Len3; ++iter3)
                    {
                        // do work
                        transfer_data<SrcData,
...
...
@@ -453,8 +230,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                                      dst_desc.GetElementSpaceSize());

                        // move dim1 iterator
-                       if(i3 < Len3 - 1)
+                       if(iter3 < Len3 - 1)
                        {
+                           bool forward_dim3 = (iter2 % 2 == 0);

                            if(forward_dim3)
                            {
                                move_dynamic_tensor_coordinate(
...
...
@@ -472,12 +251,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                        }
                    }

-                   // switch dim3 iteration direction
-                   forward_dim3 = !forward_dim3;

                    // move dim1 iterator
-                   if(i2 < Len2 - 1)
+                   if(iter2 < Len2 - 1)
                    {
+                       bool forward_dim2 = (iter1 % 2 == 0);

                        if(forward_dim2)
                        {
                            move_dynamic_tensor_coordinate(
...
...
@@ -495,12 +273,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                    }
                }

-               // switch dim2 iteration direction
-               forward_dim2 = !forward_dim2;

                // move dim1 iterator
-               if(i1 < Len1 - 1)
+               if(iter1 < Len1 - 1)
                {
+                   bool forward_dim1 = (iter0 % 2 == 0);

                    if(forward_dim1)
                    {
                        move_dynamic_tensor_coordinate(
...
...
@@ -518,59 +295,132 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                }
            }

-           // switch dim1 iteration direction
-           forward_dim1 = !forward_dim1;

            // move dim0 iterator
-           if(i0 < Len0 - 1)
+           if(iter0 < Len0 - 1)
            {
-               if(forward_dim0)
-               {
-                   move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0_0_0);
-                   move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0_0_0);
-               }
-               else
-               {
-                   move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_m1_0_0_0);
-                   move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_m1_0_0_0);
-               }
+               // move forward in dim0
+               move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0_0_0);
+               move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0_0_0);
            }
        }
        }
        // move src and dst coordinate back to their origins
-       if constexpr(MoveBackSrcCoord)
+       if constexpr(SrcResetCoordinateAfterRun)
        {
-           const auto src_step_back =
-               make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateStepBack());
+           const auto src_back_step =
+               make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateBackStep());

-           move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_back);
+           move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_back_step);
        }

-       if constexpr(MoveBackDstCoord)
+       if constexpr(DstResetCoordinateAfterRun)
        {
-           const auto dst_step_back =
-               make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateStepBack());
+           const auto dst_back_step =
+               make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateBackStep());

-           move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_back);
+           move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_back_step);
        }
    }
-   __device__ static constexpr auto GetCoordinateStepBack()
+   __device__ static constexpr auto GetCoordinateBackStep()
    {
-       MultiIndex<nDim> step_back;
+       MultiIndex<nDim> back_step;

-       step_back(Number<0>{}) = 1 - SliceLengths{}[0];
+       back_step(Number<0>{}) = 1 - SliceLengths{}[0];

        static_for<1, nDim, 1>{}([&](auto i) {
-           step_back(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
+           back_step(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
        });

-       return step_back;
+       return back_step;
    }
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // is it OK to construct a new step every time?
        const auto src_slice_origin_step =
            make_dynamic_tensor_coordinate_step(src_desc, src_slice_origin_step_idx);

        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_slice_origin_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // is it OK to construct a new step every time?
        const auto dst_slice_origin_step =
            make_dynamic_tensor_coordinate_step(dst_desc, dst_slice_origin_step_idx);

        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_slice_origin_step);
    }

    private:
    SrcCoord src_slice_origin_;
    DstCoord dst_slice_origin_;
};
// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// 3. It does not use pointer for VGPR thread buffer
// 4. It calculate offset for thread buffer directly, instead of moving the coordinate
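// (illustration of points 3 and 4: a thread buffer declared as "SrcData buffer[N]" and addressed
//  through a runtime pointer or runtime offset is what typically gets lowered to an alloca, i.e.
//  scratch memory, in the IR; a StaticallyIndexedArray addressed as buffer_(Number<offset>{}) with
//  a constexpr offset lets every element resolve to its own register at compile time)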
template <typename SliceLengths,
          InMemoryDataOperation DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                           // RunRead(), will be fused with MoveSrcSliceWindow to
                                           // save addr computation
          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordStep = decltype(make_dynamic_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_dynamic_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
                                                                 const Index& src_slice_origin,
                                                                 const DstDesc& dst_desc,
                                                                 const Index& dst_slice_origin)
        : src_slice_origin_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
          dst_slice_origin_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
    {
        static_assert(SrcAddressSpace == AddressSpace::Global or
                          SrcAddressSpace == AddressSpace::Lds,
                      "wrong!");
        static_assert(DstAddressSpace == AddressSpace::Global or
                          DstAddressSpace == AddressSpace::Lds,
                      "wrong!");
    }

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3()
        : ThreadwiseDynamicTensorSliceTransfer_v3(
              SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
...
@@ -583,15 +433,188 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
        dst_slice_origin_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
    {
        static_assert(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2,
                      "wrong! hardcoded for 2D tensor");

        // hardcoded for 2D
        // TODO implemente N-D
        if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2)
        {
            // TODO use constexpr for coordinate-step to make sure compiler behave correctly
            const auto src_step_0_p1 =
                make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 1));
            const auto src_step_0_m1 =
                make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, -1));

            const auto src_step_p1_0 =
                make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(1, 0));
            const auto src_step_m1_0 =
                make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-1, 0));

            constexpr index_t Len0 = SliceLengths{}[0];
            constexpr index_t Len1 = SliceLengths{}[1];

            static_for<0, Len0, 1>{}([&](auto iter0) {
                static_for<0, Len1, 1>{}([&](auto iter1) {
                    // step direction
                    constexpr bool forward_dim1 = (iter0.value % 2 == 0);

                    constexpr index_t i0 = iter0.value;
                    constexpr index_t i1 = forward_dim1 ? iter1.value : Len1 - iter1.value - 1;

                    // do work
                    constexpr index_t buffer_offset =
                        buffer_desc_.CalculateOffset(make_multi_index(i0, i1));

                    // hardcoding for buffer_load
                    // TODO refactor transfer_data() to encapsulate this
                    static_assert(SrcAddressSpace == AddressSpace::Global,
                                  "wrong! hardcoded to use buffer_load, src must be global mem");

                    buffer_(Number<buffer_offset>{}) = amd_buffer_load<SrcData, 1>(
                        p_src,
                        src_slice_origin_.GetOffset(),
                        coordinate_has_valid_offset_assuming_visible_index_is_valid(
                            src_desc, src_slice_origin_),
                        src_desc.GetElementSpaceSize());

                    // move dim1 iterator
                    if constexpr(iter1.value < Len1 - 1)
                    {
                        if constexpr(forward_dim1)
                        {
                            move_dynamic_tensor_coordinate(
                                src_desc, src_slice_origin_, src_step_0_p1);
                        }
                        else
                        {
                            move_dynamic_tensor_coordinate(
                                src_desc, src_slice_origin_, src_step_0_m1);
                        }
                    }
                });

                // move dim0 iterator
                if constexpr(iter0.value < Len0 - 1)
                {
                    move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
                }
            });
        }

        // move src and dst coordinate back to their origins
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_back_step =
                make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateBackStep());

            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_back_step);
        }
    }
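    // note: RunRead() only advances src_slice_origin_ and fills buffer_; dst_slice_origin_ is
    // untouched until RunWrite(), so the caller can schedule the global reads and LDS writes
    // independently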
    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
    {
        static_assert(remove_reference_t<DstDesc>::GetNumOfDimension() == 2,
                      "wrong! hardcoded for 2D tensor");

        // hardcoded for 2D
        // TODO implement N-D
        if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2)
        {
            // TODO use constexpr for coordinate-step to make sure compiler behave correctly
            const auto dst_step_0_p1 =
                make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1));
            const auto dst_step_0_m1 =
                make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1));

            const auto dst_step_p1_0 =
                make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0));
            const auto dst_step_m1_0 =
                make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0));

            constexpr index_t Len0 = SliceLengths{}[0];
            constexpr index_t Len1 = SliceLengths{}[1];

            static_for<0, Len0, 1>{}([&](auto iter0) {
                static_for<0, Len1, 1>{}([&](auto iter1) {
                    // step direction
                    constexpr bool forward_dim1 = (iter0.value % 2 == 0);

                    constexpr index_t i0 = iter0;
                    constexpr index_t i1 = forward_dim1 ? iter1.value : Len1 - iter1.value - 1;

                    // do work
                    constexpr index_t buffer_offset =
                        buffer_desc_.CalculateOffset(make_multi_index(i0, i1));

                    // hardcoding for ds_write
                    // TODO refactor transfer_data() to encapsulate this
                    static_assert(DstAddressSpace == AddressSpace::Lds &&
                                      DstInMemOp == InMemoryDataOperation::Set,
                                  "wrong! hardcoded for ds_write");

                    p_dst[dst_slice_origin_.GetOffset()] = buffer_[Number<buffer_offset>{}];

                    // move dim1 iterator
                    if constexpr(iter1.value < Len1 - 1)
                    {
                        if constexpr(forward_dim1)
                        {
                            move_dynamic_tensor_coordinate(
                                dst_desc, dst_slice_origin_, dst_step_0_p1);
                        }
                        else
                        {
                            move_dynamic_tensor_coordinate(
                                dst_desc, dst_slice_origin_, dst_step_0_m1);
                        }
                    }
                });

                // move dim0 iterator
                if constexpr(iter0.value < Len0 - 1)
                {
                    move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
                }
            });
        }

        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_back_step =
                make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateBackStep());

            move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_back_step);
        }
    }
    __device__ static constexpr auto GetCoordinateBackStep()
    {
        MultiIndex<nDim> back_step;

        back_step(Number<0>{}) = 1 - SliceLengths{}[0];

        static_for<1, nDim, 1>{}([&](auto i) {
            back_step(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
        });

        return back_step;
    }
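    // worked example (values are illustrative): for SliceLengths = {3, 4} the snake walk ends at
    // index (2, 3), and since Len0 = 3 is odd the back step is (1-3, 1-4) = (-2, -3); for
    // SliceLengths = {4, 3} the last row is walked backwards, dim1 ends at 0, and the back step
    // is (-3, 0)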
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // is it OK to construct a new step every time?
-       const auto src_slice_origin_step =
-           make_dynamic_tensor_coordinate_step(src_desc, src_slice_origin_step_idx);
+       const auto adjusted_step_idx = SrcResetCoordinateAfterRun
+                                          ? src_slice_origin_step_idx
+                                          : src_slice_origin_step_idx + GetCoordinateBackStep();

-       move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_slice_origin_step);
+       const auto adjusted_step = make_dynamic_tensor_coordinate_step(src_desc, adjusted_step_idx);
+
+       move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, adjusted_step);
    }
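    // note: when SrcResetCoordinateAfterRun is false, RunRead() leaves the coordinate at the end
    // of the snake walk and the back step is folded into the window step here, so only one
    // move_dynamic_tensor_coordinate() is paid per window move (e.g. for SliceLengths = {4, 3}
    // and a window step of (4, 0), the fused step is (4, 0) + (-3, 0) = (1, 0))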
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
...
...
@@ -599,13 +622,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                                       const Index& dst_slice_origin_step_idx)
    {
        // is it OK to construct a new step every time?
-       const auto dst_slice_origin_step =
-           make_dynamic_tensor_coordinate_step(dst_desc, dst_slice_origin_step_idx);
+       const auto adjusted_step_idx = DstResetCoordinateAfterRun
+                                          ? dst_slice_origin_step_idx
+                                          : dst_slice_origin_step_idx + GetCoordinateBackStep();

-       move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_slice_origin_step);
+       const auto adjusted_step = make_dynamic_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+
+       move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, adjusted_step);
    }

    private:
    static constexpr auto buffer_desc_ =
        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(SliceLengths{}));

    static constexpr index_t buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticallyIndexedArray<SrcData, buffer_size_> buffer_;

    SrcCoord src_slice_origin_;
    DstCoord dst_slice_origin_;
};
...
...
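To make the traversal and back-step logic in this file concrete, the following standalone sketch (plain C++, no composable_kernel headers; the 3x4 slice lengths are arbitrary example values, not taken from the commit) walks a 2D slice in the same snake order as Run()/RunRead() above and checks that the GetCoordinateBackStep() formula returns the coordinate to the slice origin.

#include <cstdio>

int main()
{
    constexpr int Len0 = 3;
    constexpr int Len1 = 4;

    int idx0 = 0, idx1 = 0; // current multi-index, starts at the slice origin

    for(int iter0 = 0; iter0 < Len0; ++iter0)
    {
        const bool forward_dim1 = (iter0 % 2 == 0); // even rows go forward, odd rows backward

        for(int iter1 = 0; iter1 < Len1; ++iter1)
        {
            const int i1 = forward_dim1 ? iter1 : Len1 - iter1 - 1;
            std::printf("visit (%d, %d)\n", iter0, i1);

            if(iter1 < Len1 - 1)
                idx1 += forward_dim1 ? 1 : -1; // dim1 iterator move
        }

        if(iter0 < Len0 - 1)
            idx0 += 1; // dim0 iterator always moves forward
    }

    // back-step formula from GetCoordinateBackStep()
    const int back0 = 1 - Len0;
    const int back1 = (Len0 % 2 == 0) ? 0 : (1 - Len1);

    std::printf("end of walk: (%d, %d), back step: (%d, %d)\n", idx0, idx1, back0, back1);
    std::printf("after back step: (%d, %d)\n", idx0 + back0, idx1 + back1); // prints (0, 0)

    return 0;
}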
driver/include/conv_common.hpp
View file @ 760a234f
...
...
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
}

template <class InDesc, class WeiDesc, class OutDesc>
-constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
+constexpr std::size_t calculate_convolution_flops(const InDesc& in_desc,
+                                                  const WeiDesc& wei_desc,
+                                                  const OutDesc& out_desc)
{
    using namespace ck;

-   constexpr auto wei_desc = WeiDesc{};
-   constexpr auto out_desc = OutDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

-   constexpr index_t N  = out_desc.GetLength(I0);
-   constexpr index_t K  = out_desc.GetLength(I1);
-   constexpr index_t Ho = out_desc.GetLength(I2);
-   constexpr index_t Wo = out_desc.GetLength(I3);
+   const index_t N  = out_desc.GetLength(I0);
+   const index_t K  = out_desc.GetLength(I1);
+   const index_t Ho = out_desc.GetLength(I2);
+   const index_t Wo = out_desc.GetLength(I3);

-   constexpr index_t C = wei_desc.GetLength(I1);
-   constexpr index_t Y = wei_desc.GetLength(I2);
-   constexpr index_t X = wei_desc.GetLength(I3);
+   const index_t C = wei_desc.GetLength(I1);
+   const index_t Y = wei_desc.GetLength(I2);
+   const index_t X = wei_desc.GetLength(I3);

    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}
...
...
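As a quick sanity check of the calculate_convolution_flops() formula above, here is a standalone example; the tensor shape is an arbitrary assumption, not one used by the driver.

#include <cstddef>
#include <cstdio>

int main()
{
    // example shape: N=128, K=256, Ho=Wo=28, C=128, Y=X=3
    const std::size_t N = 128, K = 256, Ho = 28, Wo = 28, C = 128, Y = 3, X = 3;

    // same formula as calculate_convolution_flops(): 2 * N * K * Ho * Wo * C * Y * X
    const std::size_t flops = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

    std::printf("%zu FLOP (about %.1f GFLOP)\n", flops, flops / 1.0e9);
    // prints: 59190018048 FLOP (about 59.2 GFLOP)
    return 0;
}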
driver/src/conv_driver.cpp
View file @ 760a234f
...
...
@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
-#elif 0
+#elif 1
    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
...
...