Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8b5e63ed
Commit
8b5e63ed
authored
Mar 22, 2021
by
Jing Zhang
Browse files
mock up
parent
95a5af02
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
98 deletions
+109
-98
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+21
-19
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+31
-30
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+39
-35
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+18
-14
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
8b5e63ed
...
@@ -10,8 +10,9 @@
...
@@ -10,8 +10,9 @@
namespace
ck
{
namespace
ck
{
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
Float
,
typename
FloatAB
,
typename
AccFloat
,
typename
FloatAcc
,
typename
FloatC
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
HoPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
WoPerBlock
,
...
@@ -42,9 +43,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -42,9 +43,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
AB
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
const
Float
AB
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
Float
C
*
__restrict__
p_out_global
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -166,8 +167,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -166,8 +167,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
// GEMM
// GEMM
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v3
<
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v3
<
BlockSize
,
BlockSize
,
Float
,
FloatAB
,
AccFloat
,
FloatAcc
,
FloatC
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
...
@@ -227,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -227,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const
auto
kernel
=
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
out_gemmm_n_ho_wo_global_desc
),
decltype
(
out_gemmm_n_ho_wo_global_desc
),
Float
*
,
Float
C
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>>
;
...
@@ -254,11 +256,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -254,11 +256,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const
auto
kernel
=
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
out_gemmm_n_ho_wo_global_desc
),
decltype
(
out_gemmm_n_ho_wo_global_desc
),
Float
*
,
Float
C
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>>
;
...
@@ -281,11 +283,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -281,11 +283,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const
auto
kernel
=
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
out_gemmm_n_ho_wo_global_desc
),
decltype
(
out_gemmm_n_ho_wo_global_desc
),
Float
*
,
Float
C
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>>
;
...
@@ -308,11 +310,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
...
@@ -308,11 +310,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const
auto
kernel
=
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
const
Float
*
,
const
Float
AB
*
,
decltype
(
out_gemmm_n_ho_wo_global_desc
),
decltype
(
out_gemmm_n_ho_wo_global_desc
),
Float
*
,
Float
C
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>>
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
8b5e63ed
...
@@ -12,8 +12,9 @@
...
@@ -12,8 +12,9 @@
namespace
ck
{
namespace
ck
{
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
Float
,
typename
FloatAB
,
typename
AccFloat
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
typename
AGlobalDesc
,
typename
AGlobalDesc
,
typename
BGlobalDesc
,
typename
BGlobalDesc
,
...
@@ -64,17 +65,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -64,17 +65,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_e_k_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_e_k_desc
.
GetElementSpaceSize
(),
max_lds_align
);
return
a_block_space_size
*
sizeof
(
Float
);
return
a_block_space_size
*
sizeof
(
Float
AB
);
}
}
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
AB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
AB
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
C
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
,
Float
AB
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
...
@@ -177,8 +178,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -177,8 +178,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ABlockTransferThreadSliceLengths_E_K
,
ABlockTransferThreadSliceLengths_E_K
,
ABlockTransferThreadClusterLengths_E_K
,
ABlockTransferThreadClusterLengths_E_K
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
Float
,
Float
AB
,
Float
,
Float
AB
,
decltype
(
a_e_k_global_desc
),
decltype
(
a_e_k_global_desc
),
decltype
(
a_e_k_desc
),
decltype
(
a_e_k_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -203,8 +204,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -203,8 +204,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseDynamicTensorSliceTransfer_v2
<
auto
b_threadwise_transfer
=
ThreadwiseDynamicTensorSliceTransfer_v2
<
Float
,
Float
AB
,
Float
,
Float
AB
,
decltype
(
b_e_n_ho_wo_global_desc
),
decltype
(
b_e_n_ho_wo_global_desc
),
decltype
(
b_e_n_ho_wo_thread_desc
),
decltype
(
b_e_n_ho_wo_thread_desc
),
Sequence
<
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
Sequence
<
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
...
@@ -218,10 +219,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -218,10 +219,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
true
>
(
b_e_n_ho_wo_global_desc
,
true
>
(
b_e_n_ho_wo_global_desc
,
make_multi_index
(
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
make_multi_index
(
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
Float
*
p_a_block
=
p_shared_block
;
Float
AB
*
p_a_block
=
p_shared_block
;
// register allocation for output
// register allocation for output
Acc
Float
p_c_thread
[
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
()];
Float
Acc
p_c_thread
[
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
()];
// zero out threadwise output
// zero out threadwise output
threadwise_matrix_set_zero_v3
(
c_k_n_ho_wo_thread_desc
,
p_c_thread
);
threadwise_matrix_set_zero_v3
(
c_k_n_ho_wo_thread_desc
,
p_c_thread
);
...
@@ -240,9 +241,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -240,9 +241,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BGlobalMoveSliceWindowIteratorHacks
{};
BGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_thread_space_size
=
b_e_n_ho_wo_thread_desc
.
GetElementSpaceSize
();
constexpr
auto
b_thread_space_size
=
b_e_n_ho_wo_thread_desc
.
GetElementSpaceSize
();
Float
p_b_thread
[
b_thread_space_size
*
2
];
Float
AB
p_b_thread
[
b_thread_space_size
*
2
];
Float
*
p_b_thread_double
=
p_b_thread
;
Float
AB
*
p_b_thread_double
=
p_b_thread
;
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
...
@@ -265,8 +266,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -265,8 +266,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1
#if 1
if
constexpr
(
HasMainKBlockLoop
)
if
constexpr
(
HasMainKBlockLoop
)
{
{
Float
*
p_b_thread_even
=
p_b_thread_double
;
Float
AB
*
p_b_thread_even
=
p_b_thread_double
;
Float
*
p_b_thread_odd
=
p_b_thread_double
+
b_thread_space_size
;
Float
AB
*
p_b_thread_odd
=
p_b_thread_double
+
b_thread_space_size
;
// LDS double buffer: main body
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
// use Do-While loop instead of For loop to simplify control flow
...
@@ -359,8 +360,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -359,8 +360,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
Acc
Float
,
Float
Acc
,
Float
,
Float
C
,
decltype
(
c_k_n_ho_wo_thread_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
decltype
(
c_k_n_ho_wo_global_desc
),
decltype
(
c_k_n_ho_wo_global_desc
),
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>
,
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>
,
...
@@ -388,17 +389,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -388,17 +389,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptor by reference
// pass tensor descriptor by reference
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
AB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
AB
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
C
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
);
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
AB
);
__shared__
Float
p_shared_block
[
shared_block_size
];
__shared__
Float
AB
p_shared_block
[
shared_block_size
];
Run
(
a_e_k_global_desc
,
Run
(
a_e_k_global_desc
,
p_a_global
,
p_a_global
,
...
@@ -414,11 +415,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -414,11 +415,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by their pointers
// pass tensor descriptors by their pointers
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
*
p_a_e_k_global_desc
,
__device__
void
Run
(
const
AGlobalDesc
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
AB
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
*
p_b_e_n_ho_wo_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
AB
*
__restrict__
p_b_global
,
const
CGlobalDesc
*
p_c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
*
p_c_k_n_ho_wo_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
C
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
...
@@ -439,11 +440,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -439,11 +440,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by void*
// pass tensor descriptors by void*
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
void
*
p_a_e_k_global_desc
,
__device__
void
Run
(
const
void
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
AB
*
__restrict__
p_a_global
,
const
void
*
p_b_e_n_ho_wo_global_desc
,
const
void
*
p_b_e_n_ho_wo_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
AB
*
__restrict__
p_b_global
,
const
void
*
p_c_k_n_ho_wo_global_desc
,
const
void
*
p_c_k_n_ho_wo_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
C
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
8b5e63ed
...
@@ -3,7 +3,10 @@
...
@@ -3,7 +3,10 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
template
<
class
T
,
template
<
class
TInWei
,
ck
::
index_t
InWeiVectorSize
,
class
TAcc
,
class
TOut
,
class
InDesc
,
class
InDesc
,
class
WeiDesc
,
class
WeiDesc
,
class
OutDesc
,
class
OutDesc
,
...
@@ -11,33 +14,31 @@ template <class T,
...
@@ -11,33 +14,31 @@ template <class T,
class
ConvDilations
,
class
ConvDilations
,
class
InLeftPads
,
class
InLeftPads
,
class
InRightPads
>
class
InRightPads
>
void
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
(
InDesc
,
void
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
(
const
Tensor
<
T
>&
in_nchw
,
InDesc
,
WeiDesc
,
const
Tensor
<
TInWei
>&
in_n_c_hi_wi
,
const
Tensor
<
T
>&
wei_kcyx
,
WeiDesc
,
OutDesc
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
Tensor
<
T
>&
out_nkhw
,
OutDesc
,
ConvStrides
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
ConvDilations
,
ConvStrides
,
InLeftPads
,
ConvDilations
,
InRightPads
,
InLeftPads
,
ck
::
index_t
nrepeat
)
InRightPads
,
ck
::
index_t
nrepeat
)
{
{
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
<<
std
::
endl
;
using
namespace
ck
;
using
namespace
ck
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
<<
std
::
endl
;
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_k_ho_wo_device_buf
(
sizeof
(
TOut
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_n
chw
_device_buf
.
ToDevice
(
in_n
chw
.
mData
.
data
());
in_n
_c_hi_wi
_device_buf
.
ToDevice
(
in_n
_c_hi_wi
.
mData
.
data
());
wei_k
cy
x_device_buf
.
ToDevice
(
wei_k
cy
x
.
mData
.
data
());
wei_k
_c_y_
x_device_buf
.
ToDevice
(
wei_k
_c_y_
x
.
mData
.
data
());
out_n
khw
_device_buf
.
ToDevice
(
out_n
khw
.
mData
.
data
());
out_n
_k_ho_wo
_device_buf
.
ToDevice
(
out_n
_k_ho_wo
.
mData
.
data
());
#if 0
#if 0
// run-time variables
// run-time variables
...
@@ -70,18 +71,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
...
@@ -70,18 +71,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
// cdata = 16, BlockSize = 64, 16x64x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
16
;
constexpr
index_t
HoPerBlock
=
16
;
constexpr
index_t
WoPerBlock
=
16
;
constexpr
index_t
WoPerBlock
=
16
;
constexpr
index_t
EPerBlock
=
4
;
constexpr
index_t
EPerBlock
=
1
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
4
;
constexpr
index_t
EPerThread
=
1
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
9
,
1
>
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
1
,
1
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
4
,
16
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
9
,
4
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
...
@@ -93,8 +94,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
...
@@ -93,8 +94,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
<
BlockSize
,
BlockSize
,
TDevice
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TDevice
,
TAcc
,
TOut
,
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
...
@@ -117,9 +119,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
...
@@ -117,9 +119,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
in_right_pads
,
in_right_pads
,
static_cast
<
TDevice
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
TDevice
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()));
out_n
khw
_device_buf
.
FromDevice
(
out_n
khw
.
mData
.
data
());
out_n
_k_ho_wo
_device_buf
.
FromDevice
(
out_n
_k_ho_wo
.
mData
.
data
());
}
}
driver/src/conv_driver.cpp
View file @
8b5e63ed
...
@@ -80,10 +80,10 @@ int main(int argc, char* argv[])
...
@@ -80,10 +80,10 @@ int main(int argc, char* argv[])
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
#elif 1
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
C
=
1
;
constexpr
index_t
HI
=
1024
;
constexpr
index_t
HI
=
1024
;
constexpr
index_t
WI
=
2048
;
constexpr
index_t
WI
=
2048
;
constexpr
index_t
K
=
16
;
constexpr
index_t
K
=
4
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
...
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array
(
"ConvStrides"
,
to_multi_index
(
ConvStrides
{}));
print_array
(
"ConvStrides"
,
to_multi_index
(
ConvStrides
{}));
print_array
(
"ConvDilations"
,
to_multi_index
(
ConvDilations
{}));
print_array
(
"ConvDilations"
,
to_multi_index
(
ConvDilations
{}));
#if
1
#if
0
using in_data_t = float;
using in_data_t = float;
constexpr index_t in_vector_size = 1;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using acc_data_t = float;
...
@@ -754,17 +754,21 @@ int main(int argc, char* argv[])
...
@@ -754,17 +754,21 @@ int main(int argc, char* argv[])
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif 1
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
<
in_data_t
,
in_nchw
,
in_vector_size
,
wei_kcyx_desc
,
acc_data_t
,
wei_kcyx
,
out_data_t
>
(
out_nkhw_desc
,
in_nchw_desc
,
out_nkhw_device
,
in_nchw
,
ConvStrides
{},
wei_kcyx_desc
,
ConvDilations
{},
wei_kcyx
,
LeftPads
{},
out_nkhw_desc
,
RightPads
{},
out_nkhw_device
,
nrepeat
);
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#endif
#endif
if
(
do_verification
)
if
(
do_verification
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment