gaoqiong / composable_kernel · Commits · 510b3a21

Commit 510b3a21, authored May 12, 2021 by Chao Liu
parent aac345ab

    move AddressSpace info from copy operator into DynamicBuffer and StaticBuffer
Showing 11 changed files with 98 additions and 112 deletions (+98 −112):

    composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp   +0  −4
    composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                        +10 −12
    composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                         +1  −3
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                    +16 −25
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp                  +8 −15
    composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +14 −23
    composable_kernel/include/utility/common_header.hpp                                      +2  −1
    composable_kernel/include/utility/config.amd.hpp.in                                      +1  −0
    composable_kernel/include/utility/dynamic_buffer.hpp                                    +10 −26
    composable_kernel/include/utility/static_buffer.hpp                                    +33  −0
    driver/src/conv_driver.cpp                                                               +3  −3
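The refactoring pattern is the same across the files: AddressSpace stops being a template parameter of each copy/transfer operator and becomes a template parameter of the buffer types handed to it. A minimal sketch of the idea, simplified from the diffs below (the struct body here is illustrative, not the repository's exact code):

    enum AddressSpace { Generic, Global, Lds, Sgpr, Vgpr };

    // Before: every transfer operator carried the address spaces itself, e.g.
    //   template <..., AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace, ...>
    //   struct ThreadwiseDynamicTensorSliceTransfer_v1r3;

    // After: the buffer type carries that information and exposes it as a
    // compile-time query that operators can interrogate.
    template <AddressSpace BufferAddressSpace, typename T>
    struct DynamicBuffer
    {
        __host__ __device__ static constexpr AddressSpace GetAddressSpace()
        {
            return BufferAddressSpace;
        }

        T* p_data_;
    };

The per-file diffs follow.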
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp

@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,

@@ -153,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
                                                   DstScalarPerVector,
                                                   SrcScalarStrideInVector,
                                                   DstScalarStrideInVector,
-                                                  SrcAddressSpace,
-                                                  DstAddressSpace,
                                                   ThreadTransferSrcResetCoordinateAfterRun,
                                                   ThreadTransferDstResetCoordinateAfterRun>;
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,

@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                           Sequence<0, 1, 2>,
                                           2,
                                           AThreadCopyScalarPerVector_M1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                           1>;

     using BThreadCopy =

@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                           Sequence<0, 1, 2>,
                                           2,
                                           BThreadCopyScalarPerVector_N1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                           1>;

     CIndex c_thread_origin_data_idx_;

@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 // 3. C:
 //    1. CThreadDesc is known at compile-time
 //    2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,

@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,

@@ -481,8 +483,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                           Sequence<0, 1, 2>,
                                           2,
                                           AThreadCopyScalarPerVector_M1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                           1>;

     using BThreadCopy =

@@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                           Sequence<0, 1, 2>,
                                           2,
                                           BThreadCopyScalarPerVector_N1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                           1>;

     CIndex c_thread_origin_data_idx_;
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                                 Sequence<0, 1>,
                                                 1,
                                                 ThreadGemmADataPerRead_K,
-                                                AddressSpace::Generic,
-                                                AddressSpace::Vgpr,
                                                 1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()

@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -35,8 +35,7 @@ __global__ void
                const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                const CBlockClusterDesc c_block_cluster_desc)
 {
     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
                       a_k_m_global_desc,

@@ -85,8 +84,7 @@ __global__ void
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
                       a_k_m_global_desc,

@@ -171,8 +169,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,

@@ -188,9 +185,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const auto a_global_buf = make_dynamic_buffer(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer(p_b_global);
-        auto c_global_buf       = make_dynamic_buffer(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
+        auto c_global_buf       = make_dynamic_buffer<AddressSpace::Global>(p_c_global);

         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);

@@ -241,8 +238,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,

@@ -270,8 +265,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,

@@ -346,8 +339,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
         auto c_thread_buf =
-            make_static_buffer<FloatAcc>(
+            make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
                 c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),

@@ -368,11 +361,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_double);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_double);
-        auto a_block_odd_buf  = make_dynamic_buffer(p_a_block_double + a_block_space_size);
-        auto b_block_odd_buf  = make_dynamic_buffer(p_b_block_double + b_block_space_size);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double);
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double);
+        auto a_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_space_size);
+        auto b_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_space_size);

         // LDS double buffer: preload data into LDS
         {

@@ -497,8 +492,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{

@@ -517,8 +510,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,

@@ -532,8 +524,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         __shared__ FloatAB p_shared_block[shared_block_size];

         Run(p_a_global,
             p_b_global,
             p_c_global,
             a_k_m_global_desc,
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp

@@ -84,9 +84,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const auto a_global_buf = make_dynamic_buffer(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer(p_b_global);
-        auto c_global_buf       = make_dynamic_buffer(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
+        auto c_global_buf       = make_dynamic_buffer<AddressSpace::Global>(p_c_global);

         constexpr auto E = EPerBlock * 3 * 3;

@@ -196,8 +196,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,

@@ -220,19 +218,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-        auto a_block_buf   = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block);

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,

@@ -254,8 +249,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double regsiter buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {

@@ -362,8 +357,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -54,8 +54,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp,
           index_t DstScalarStrideInVector,
           bool DstResetCoordinateAfterRun,

@@ -211,8 +209,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             dst_desc, dst_slice_origin_coord_);

-        if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
-                     DstAddressSpace == AddressSpace::Global)
+        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Vgpr &&
+                     DstBuffer::GetAddressSpace() == AddressSpace::Global)
         {
 #if CK_USE_AMD_BUFFER_ADDRESSING
             amd_buffer_store_v2<DstData, DstScalarPerVector>(

@@ -403,8 +401,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t SrcVectorDim,
           index_t SrcScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
           typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>

@@ -541,8 +537,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         }();

         // copy data
-        static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
-
         typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

         using src_vector_t =

@@ -551,7 +545,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             src_desc, src_slice_origin_coord_);

-        if constexpr(SrcAddressSpace == AddressSpace::Global)
+        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
         {
 #if CK_USE_AMD_BUFFER_ADDRESSING
             src_vector.template AsType<src_vector_t>()(Number<0>{}) =

@@ -748,8 +742,6 @@ template <typename SliceLengths,
           index_t DstScalarPerVector,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                            // RunRead(), will be fused with MoveSrcSliceWindow to
                                            // save addr computation

@@ -774,13 +766,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
           dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
     {
-        static_assert(SrcAddressSpace == AddressSpace::Global or
-                          SrcAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-        static_assert(DstAddressSpace == AddressSpace::Global or
-                          DstAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-
         // TODO: fix this
         static_assert(is_same<SrcData, DstData>::value,
                       "wrong! current implementation assume SrcData and DstData are same type");

@@ -801,6 +786,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                   const SrcBuffer& src_buf,
                   const SrcIteratorHacks& src_iterator_hacks)
     {
+        static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
+                          SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
         static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
                               remove_cv_t<remove_reference_t<SrcData>>>::value,
                       "wrong! SrcBuffer and SrcData data type are inconsistent");

@@ -897,7 +886,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             src_desc, src_slice_origin_coord_);

-        if constexpr(SrcAddressSpace == AddressSpace::Global)
+        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
         {
 #if CK_USE_AMD_BUFFER_ADDRESSING
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =

@@ -983,6 +972,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                   DstBuffer& dst_buf,
                   const DstIteratorHacks& dst_iterator_hacks)
     {
+        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
+                          DstBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
         static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                               remove_cv_t<remove_reference_t<DstData>>>::value,
                       "wrong! SrcBuffer or DstBuffer data type is wrong");

@@ -1078,7 +1071,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         // copy data
         // hardcoding for ds_write
         // TODO refactor transfer_data() to encapsulate this
-        static_assert(DstAddressSpace == AddressSpace::Lds &&
+        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Lds &&
                           DstInMemOp == InMemoryDataOperation::Set,
                       "wrong! hardcoded for ds_write");

@@ -1356,7 +1349,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
     static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-    StaticBuffer<SrcData, buffer_size_> buffer_;
+    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;

     SrcCoord src_slice_origin_coord_;
     DstCoord dst_slice_origin_coord_;

@@ -1384,8 +1377,6 @@ template <
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
-         AddressSpace SrcAddressSpace,
-         AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
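The net effect in this file: the transfer operators drop their SrcAddressSpace/DstAddressSpace template parameters and instead interrogate the buffer arguments at compile time, and the address-space static_asserts move from the constructor (which no longer sees any address spaces) into RunRead/RunWrite (which do). A hedged sketch of the dispatch pattern; names and body are illustrative, not the operator's real signature:

    // Illustrative only: the real operators also take descriptors, coordinates, etc.
    template <typename SrcBuffer, typename DstBuffer>
    __device__ void RunWrite(const SrcBuffer& src_buf, DstBuffer& dst_buf)
    {
        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Vgpr &&
                     DstBuffer::GetAddressSpace() == AddressSpace::Global)
        {
    #if CK_USE_AMD_BUFFER_ADDRESSING
            // fast path, e.g. amd_buffer_store_v2<DstData, DstScalarPerVector>(...)
    #endif
        }
        else
        {
            // generic element-wise store
        }
    }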
composable_kernel/include/utility/common_header.hpp

@@ -8,7 +8,8 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "buffer.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
composable_kernel/include/utility/config.amd.hpp.in

@@ -159,6 +159,7 @@ enum AddressSpace
     Generic,
     Global,
     Lds,
+    Sgpr,
     Vgpr
 };
composable_kernel/include/utility/buffer.hpp → composable_kernel/include/utility/dynamic_buffer.hpp

-#ifndef CK_BUFFER_HPP
-#define CK_BUFFER_HPP
-
-#include "statically_indexed_array.hpp"
+#ifndef CK_DYNAMIC_BUFFER_HPP
+#define CK_DYNAMIC_BUFFER_HPP

 namespace ck {

-template <typename T, index_t N>
-struct StaticBuffer : public StaticallyIndexedArray<T, N>
-{
-    using type = T;
-    using base = StaticallyIndexedArray<T, N>;
-
-    __host__ __device__ constexpr StaticBuffer() : base{} {}
-
-    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
-
-    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
-};
-
-template <typename T, index_t N>
-__host__ __device__ constexpr auto make_static_buffer(Number<N>)
-{
-    return StaticBuffer<T, N>{};
-}
-
-template <typename T>
+template <AddressSpace BufferAddressSpace, typename T>
 struct DynamicBuffer
 {
     using type = T;

@@ -33,6 +12,11 @@ struct DynamicBuffer
     __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

+    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
+    {
+        return BufferAddressSpace;
+    }
+
     __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

     __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

@@ -91,10 +75,10 @@ struct DynamicBuffer
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };

-template <typename T>
+template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
 {
-    return DynamicBuffer<T>{p};
+    return DynamicBuffer<BufferAddressSpace, T>{p};
 }

 } // namespace ck
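A usage sketch for the updated factory (hypothetical caller code, not from the commit): the address space defaults to Generic, so unannotated callers keep compiling, while kernel code can tag Global or Lds pointers explicitly.

    __device__ void example(float* p_global, float* p_lds)
    {
        auto buf_generic = make_dynamic_buffer(p_global); // AddressSpace::Generic by default
        auto buf_global  = make_dynamic_buffer<AddressSpace::Global>(p_global);
        auto buf_lds     = make_dynamic_buffer<AddressSpace::Lds>(p_lds);

        static_assert(decltype(buf_global)::GetAddressSpace() == AddressSpace::Global, "");
    }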
composable_kernel/include/utility/static_buffer.hpp (new file, 0 → 100644)

#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
}

} // namespace ck
#endif
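And a matching sketch for the new static-buffer factory (again hypothetical caller code): the element count must be a compile-time Number<>, and thread-local accumulators are tagged as VGPR-resident.

    __device__ void example()
    {
        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, float>(Number<8>{});

        static_assert(decltype(c_thread_buf)::IsStaticBuffer(), "");
        static_assert(decltype(c_thread_buf)::GetAddressSpace() == AddressSpace::Vgpr, "");
    }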
driver/src/conv_driver.cpp

@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 0
+#elif 1
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 540;

@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 1
+#if 0
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;

@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
         LeftPads{},
         RightPads{},
         nrepeat);
-#elif 1
+#elif 0
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,