Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6c2a3a95
Commit
6c2a3a95
authored
Mar 18, 2021
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/dynamic_tensor_descriptor' into dynamic_tensor_descriptor_v5r1
parents
f744524e
6bf99b9e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
134 additions
and
37 deletions
+134
-37
composable_kernel/include/gridwise_operation_wrapper.hpp
composable_kernel/include/gridwise_operation_wrapper.hpp
+2
-2
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+28
-27
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+2
-2
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+67
-1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+2
-2
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+33
-3
No files found.
composable_kernel/include/gridwise_operation_wrapper.hpp
View file @
6c2a3a95
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
template
<
typename
GridwiseOp
,
typename
...
Xs
>
template
<
typename
GridwiseOp
,
typename
...
Xs
>
__global__
void
__global__
void
#if
1
#if
0
__launch_bounds__(256, 2)
__launch_bounds__(256, 2)
#endif
#endif
run_gridwise_operation
(
Xs
...
xs
)
run_gridwise_operation
(
Xs
...
xs
)
{
{
GridwiseOp
{}.
Run
(
xs
...);
GridwiseOp
{}.
Run
(
xs
...);
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
6c2a3a95
...
@@ -91,7 +91,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -91,7 +91,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const
auto
W
=
b_cyx_n_h_w_global_desc
.
GetLength
(
I3
);
const
auto
W
=
b_cyx_n_h_w_global_desc
.
GetLength
(
I3
);
// divide block work by [M, N]
// divide block work by [M, N]
#if
1
#if
0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
...
@@ -646,32 +646,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -646,32 +646,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
index_t
w_thread_data_on_global
=
w_block_data_on_global
+
w_thread_id
*
WPerThread
;
const
index_t
w_thread_data_on_global
=
w_block_data_on_global
+
w_thread_id
*
WPerThread
;
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseDynamicTensorSliceTransfer_v4
<
auto
a_blockwise_copy
=
BlockSize
,
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
Sequence
<
CYX
,
K
>
,
Sequence
<
CYX
,
K
>
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
Float
,
Float
,
Float
,
Float
,
decltype
(
a_cyx_k_global_desc
),
decltype
(
a_cyx_k_global_desc
),
decltype
(
a_cyx_k_desc
),
decltype
(
a_cyx_k_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
1
,
1
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_M
,
ABlockTransferDstScalarPerVector_M
,
AddressSpace
::
Global
,
AddressSpace
::
Global
,
AddressSpace
::
Lds
,
AddressSpace
::
Lds
,
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_cyx_k_global_desc
,
true
>
(
make_multi_index
(
0
,
k_block_data_on_global
),
a_cyx_k_global_desc
,
a_cyx_k_desc
,
make_multi_index
(
0
,
k_block_data_on_global
),
make_multi_index
(
0
,
0
));
a_cyx_k_desc
,
make_multi_index
(
0
,
0
));
constexpr
auto
b_cyx_n_h_w_thread_desc
=
constexpr
auto
b_cyx_n_h_w_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
6c2a3a95
...
@@ -37,7 +37,7 @@
...
@@ -37,7 +37,7 @@
#endif
#endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
0
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
1
#endif
#endif
#ifndef CK_USE_AMD_V_FMAC_F32
#ifndef CK_USE_AMD_V_FMAC_F32
...
@@ -74,7 +74,7 @@
...
@@ -74,7 +74,7 @@
// experimental implementation
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
1
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
0
#endif
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
6c2a3a95
...
@@ -68,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -68,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
#endif
#if
1
#if
0
// cdata = 16, BlockSize = 64, 16x64x4
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t BlockSize = 64;
...
@@ -101,6 +101,72 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -101,6 +101,72 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif
0
// cdata = 32, BlockSize 64, 16x128x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
1
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#elif 0
#elif 0
// cdata = 16, BlockSize = 64, 16x64x4
// cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
6c2a3a95
...
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
#if
1
#if
0
// cdata = 16, BlockSize = 64, 16x64x4
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t BlockSize = 64;
...
@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 64, 16x256x4
// cdata = 64, BlockSize = 64, 16x256x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
...
...
driver/src/conv_driver.cpp
View file @
6c2a3a95
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -49,8 +48,8 @@ int main(int argc, char* argv[])
...
@@ -49,8 +48,8 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
...
@@ -65,6 +64,20 @@ int main(int argc, char* argv[])
...
@@ -65,6 +64,20 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
1080
;
constexpr
index_t
WI
=
1920
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
#elif 1
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
C
=
4
;
...
@@ -724,6 +737,23 @@ int main(int argc, char* argv[])
...
@@ -724,6 +737,23 @@ int main(int argc, char* argv[])
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif 1
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk
<
in_data_t
,
in_vector_size
,
acc_data_t
,
out_data_t
>
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment