Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
03aa52bc
"src/include/ConstantMergedTensorDescriptor.hpp" did not exist on "acd7082fe109aa4228dfca652e87cab96bc6837f"
Commit
03aa52bc
authored
Apr 09, 2021
by
Jing Zhang
Browse files
fixed
parent
db4afa69
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
125 additions
and
100 deletions
+125
-100
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
...tion_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
+17
-16
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+64
-52
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+30
-26
driver/include/host_conv.hpp
driver/include/host_conv.hpp
+10
-3
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+4
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
View file @
03aa52bc
...
@@ -39,7 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -39,7 +39,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Add
...
>&
add_n_k0_
2
ho
_2wo
_k1_global_desc
,
const
DynamicTensorDescriptor
<
Add
...
>&
add_n_k0_ho
x2_wox2
_k1_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
...
@@ -66,6 +66,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -66,6 +66,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
auto
Ho
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I2
);
const
auto
Wo
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I3
);
const
auto
Hox2
=
Ho
*
2
;
const
auto
Wox2
=
Wo
*
2
;
const
auto
K1
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I4
);
const
auto
K1
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I4
);
const
auto
K
=
wei_k_c_y_x_global_desc
.
GetLength
(
I0
);
const
auto
K
=
wei_k_c_y_x_global_desc
.
GetLength
(
I0
);
...
@@ -146,18 +149,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -146,18 +149,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// add tensor
// add tensor
const
auto
add_k_n_
2
hop_
2
wop_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
add_k_n_hop
x2
_wop
x2
_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
2
*
Ho
,
2
*
Wo
,
K1
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wo
x2
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pad_transform
(
2
*
Ho
,
0
,
AddRightPadH
),
make_pad_transform
(
Ho
x2
,
0
,
AddRightPadH
),
make_pad_transform
(
2
*
Wo
,
0
,
AddRightPadW
)),
make_pad_transform
(
Wo
x2
,
0
,
AddRightPadW
)),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
E
=
C
*
Y
*
X
;
const
auto
E
=
C
*
Y
*
X
;
std
::
cerr
<<
"Hop = "
<<
Hop
<<
" Wop = "
<<
Wop
<<
std
::
endl
;
std
::
cerr
<<
"Hop = "
<<
Hop
<<
" Wop = "
<<
Wop
<<
std
::
endl
;
...
@@ -209,7 +210,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -209,7 +210,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
add_k_n_
2
hop_
2
wop_global_desc
),
decltype
(
add_k_n_hop
x2
_wop
x2
_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
...
@@ -269,7 +270,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -269,7 +270,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
add_k_n_
2
hop_
2
wop_global_desc
),
decltype
(
add_k_n_hop
x2
_wop
x2
_global_desc
),
const
FloatC
*
,
const
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
FloatC
*
,
FloatC
*
,
...
@@ -285,7 +286,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -285,7 +286,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
add_k_n_
2
hop_
2
wop_global_desc
,
add_k_n_hop
x2
_wop
x2
_global_desc
,
p_d_global
,
p_d_global
,
out_k_n_hop_wop_global_desc
,
out_k_n_hop_wop_global_desc
,
p_out_global
,
p_out_global
,
...
@@ -300,7 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -300,7 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
add_k_n_
2
hop_
2
wop_global_desc
),
decltype
(
add_k_n_hop
x2
_wop
x2
_global_desc
),
const
FloatC
*
,
const
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
FloatC
*
,
FloatC
*
,
...
@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
add_k_n_
2
hop_
2
wop_global_desc
,
add_k_n_hop
x2
_wop
x2
_global_desc
,
p_d_global
,
p_d_global
,
out_k_n_hop_wop_global_desc
,
out_k_n_hop_wop_global_desc
,
p_out_global
,
p_out_global
,
...
@@ -331,7 +332,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -331,7 +332,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
add_k_n_
2
hop_
2
wop_global_desc
),
decltype
(
add_k_n_hop
x2
_wop
x2
_global_desc
),
const
FloatC
*
,
const
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
FloatC
*
,
FloatC
*
,
...
@@ -347,7 +348,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -347,7 +348,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
add_k_n_
2
hop_
2
wop_global_desc
,
add_k_n_hop
x2
_wop
x2
_global_desc
,
p_d_global
,
p_d_global
,
out_k_n_hop_wop_global_desc
,
out_k_n_hop_wop_global_desc
,
p_out_global
,
p_out_global
,
...
@@ -362,7 +363,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -362,7 +363,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
add_k_n_
2
hop_
2
wop_global_desc
),
decltype
(
add_k_n_hop
x2
_wop
x2
_global_desc
),
const
FloatC
*
,
const
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
decltype
(
out_k_n_hop_wop_global_desc
),
FloatC
*
,
FloatC
*
,
...
@@ -378,7 +379,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -378,7 +379,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
add_k_n_
2
hop_
2
wop_global_desc
,
add_k_n_hop
x2
_wop
x2
_global_desc
,
p_d_global
,
p_d_global
,
out_k_n_hop_wop_global_desc
,
out_k_n_hop_wop_global_desc
,
p_out_global
,
p_out_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
03aa52bc
...
@@ -74,7 +74,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -74,7 +74,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
DGlobalDesc
&
d_k_n_
2
ho
_2wo
_global_desc
,
const
DGlobalDesc
&
d_k_n_ho
x2_wox2
_global_desc
,
const
FloatC
*
__restrict__
p_d_global
,
const
FloatC
*
__restrict__
p_d_global
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
...
@@ -89,7 +89,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -89,7 +89,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr
auto
E
=
EPerBlock
*
3
*
3
;
constexpr
auto
E
=
EPerBlock
*
3
*
3
;
// const auto E = a_e_k_global_desc.GetLength(I0);
const
auto
K
=
a_e_k_global_desc
.
GetLength
(
I1
);
const
auto
K
=
a_e_k_global_desc
.
GetLength
(
I1
);
const
auto
N
=
b_e_n_ho_wo_global_desc
.
GetLength
(
I1
);
const
auto
N
=
b_e_n_ho_wo_global_desc
.
GetLength
(
I1
);
...
@@ -148,10 +147,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -148,10 +147,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
constexpr
auto
d_k_n_2ho_2wo_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
2
*
HoPerThread
>
{},
Number
<
2
*
WoPerThread
>
{}));
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
decltype
(
a_e_k_block_desc
),
decltype
(
a_e_k_block_desc
),
...
@@ -358,29 +353,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -358,29 +353,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_c_thread
);
p_c_thread
);
}
}
#endif
#endif
// output: register to global memory
{
constexpr
auto
HoPerThreadx2
=
HoPerThread
*
2
;
constexpr
auto
WoPerThreadx2
=
WoPerThread
*
2
;
#if 1
const
index_t
hox2_block_data_on_global
=
ho_block_work_id
*
HoPerBlock
*
2
;
FloatC
p_d_thread
[
d_k_n_2ho_2wo_thread_desc
.
GetElementSpaceSize
()];
const
index_t
wox2_block_data_on_global
=
wo_block_work_id
*
WoPerBlock
*
2
;
threadwise_matrix_set_zero_v3
(
d_k_n_2ho_2wo_thread_desc
,
p_d_thread
);
const
index_t
ho2_thread_data_on_global
=
const
index_t
ho
x
2_thread_data_on_global
=
ho
_block_data_on_global
+
ho_thread_id
*
HoPerThread
*
2
;
hox2
_block_data_on_global
+
ho_thread_id
*
HoPerThread
x
2
;
const
index_t
wo2_thread_data_on_global
=
const
index_t
wo
x
2_thread_data_on_global
=
wo
_block_data_on_global
+
wo_thread_id
*
WoPerThread
*
2
;
wox2
_block_data_on_global
+
wo_thread_id
*
WoPerThread
x
2
;
const
index_t
k_thread_data_on_global
=
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
{
constexpr
auto
d_k_n_hox2_wox2_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThreadx2
>
{},
Number
<
WoPerThreadx2
>
{}));
constexpr
auto
c
_k_n_ho_wo
_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{}
;
FloatC
p_d_thread
[
d
_k_n_ho
x2
_wo
x2_thread_desc
.
GetElementSpaceSize
()]
;
constexpr
auto
c_k_n_ho_wo_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
#if 1
ThreadwiseDynamicTensorSliceTransfer_v2
<
ThreadwiseDynamicTensorSliceTransfer_v2
<
FloatC
,
FloatC
,
FloatC
,
FloatC
,
decltype
(
d_k_n_
2
ho
_2wo
_global_desc
),
decltype
(
d_k_n_ho
x2_wox2
_global_desc
),
decltype
(
d_k_n_
2
ho
_2wo
_thread_desc
),
decltype
(
d_k_n_ho
x2_wox2
_thread_desc
),
Sequence
<
KPerThread
,
1
,
2
*
HoPerThread
,
2
*
WoPerThread
>
,
Sequence
<
KPerThread
,
1
,
HoPerThread
x2
,
WoPerThread
x2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
...
@@ -388,36 +392,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -388,36 +392,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace
::
Vgpr
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
1
,
1
,
true
>
(
true
>
(
d_k_n_hox2_wox2_global_desc
,
d_k_n_2ho_2wo_global_desc
,
make_multi_index
(
k_thread_data_on_global
,
make_multi_index
(
0
,
k_thread_data_on_global
,
0
,
ho2_thread_data_on_global
,
wo2_thread_data_on_global
))
hox2_thread_data_on_global
,
.
Run
(
d_k_n_2ho_2wo_global_desc
,
wox2_thread_data_on_global
))
.
Run
(
d_k_n_hox2_wox2_global_desc
,
p_d_global
,
p_d_global
,
d_k_n_
2
ho
_2wo
_thread_desc
,
d_k_n_ho
x2_wox2
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_d_thread
,
p_d_thread
,
c_k_n_ho_wo_global_tensor_iterator_hacks
);
c_k_n_ho_wo_global_tensor_iterator_hacks
);
}
for
(
index_t
i
=
0
;
i
<
d_k_n_2ho_2wo_thread_desc
.
GetElementSpaceSize
();
i
++
)
{
p_d_thread
[
i
]
+=
p_c_thread
[
i
/
2
];
}
#endif
#endif
#if 1
#if 1
// output: register to global memory
for
(
index_t
k_i
=
0
;
k_i
<
KPerThread
;
++
k_i
)
{
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
for
(
index_t
h_i
=
0
;
h_i
<
HoPerThreadx2
;
++
h_i
)
constexpr
auto
c_k_n_ho_wo_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
{
for
(
index_t
w_i
=
0
;
w_i
<
WoPerThreadx2
;
++
w_i
)
{
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
CalculateOffset
(
make_tuple
(
k_i
,
0
,
h_i
,
w_i
))]
+=
p_c_thread
[
c_k_n_ho_wo_thread_desc
.
CalculateOffset
(
make_tuple
(
k_i
,
0
,
h_i
/
2
,
w_i
/
2
))];
}
}
}
#endif
#if 1
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
FloatC
,
FloatC
,
decltype
(
d_k_n_
2
ho
_2wo
_thread_desc
),
decltype
(
d_k_n_ho
x2_wox2
_thread_desc
),
decltype
(
d_k_n_
2
ho
_2wo
_global_desc
),
decltype
(
d_k_n_ho
x2_wox2
_global_desc
),
Sequence
<
KPerThread
,
1
,
2
*
HoPerThread
,
2
*
WoPerThread
>
,
Sequence
<
KPerThread
,
1
,
HoPerThread
x2
,
WoPerThread
x2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
...
@@ -425,18 +436,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -425,18 +436,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace
::
Global
,
AddressSpace
::
Global
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
1
,
1
,
true
>
(
true
>
(
d_k_n_hox2_wox2_global_desc
,
d_k_n_2ho_2wo_global_desc
,
make_multi_index
(
k_thread_data_on_global
,
make_multi_index
(
0
,
k_thread_data_on_global
,
0
,
ho2_thread_data_on_global
,
wo2_thread_data_on_global
))
hox2_thread_data_on_global
,
.
Run
(
d_k_n_2ho_2wo_thread_desc
,
wox2_thread_data_on_global
))
.
Run
(
d_k_n_hox2_wox2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_d_thread
,
p_d_thread
,
d_k_n_
2
ho
_2wo
_global_desc
,
d_k_n_ho
x2_wox2
_global_desc
,
p_c_global
,
p_c_global
,
c_k_n_ho_wo_global_tensor_iterator_hacks
);
c_k_n_ho_wo_global_tensor_iterator_hacks
);
}
#endif
#endif
}
}
}
// pass tensor descriptor by reference
// pass tensor descriptor by reference
...
@@ -445,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -445,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
DGlobalDesc
&
d_k_n_
2
ho
_2wo
_global_desc
,
const
DGlobalDesc
&
d_k_n_ho
x2_wox2
_global_desc
,
const
FloatC
*
__restrict__
p_d_global
,
const
FloatC
*
__restrict__
p_d_global
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
...
@@ -460,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -460,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global
,
p_a_global
,
b_e_n_ho_wo_global_desc
,
b_e_n_ho_wo_global_desc
,
p_b_global
,
p_b_global
,
d_k_n_
2
ho
_2wo
_global_desc
,
d_k_n_ho
x2_wox2
_global_desc
,
p_d_global
,
p_d_global
,
c_k_n_ho_wo_global_desc
,
c_k_n_ho_wo_global_desc
,
p_c_global
,
p_c_global
,
...
@@ -475,7 +487,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -475,7 +487,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
*
p_b_e_n_ho_wo_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
DGlobalDesc
&
d_k_n_
2
ho
_2wo
_global_desc
,
const
DGlobalDesc
&
d_k_n_ho
x2_wox2
_global_desc
,
const
FloatC
*
__restrict__
p_d_global
,
const
FloatC
*
__restrict__
p_d_global
,
const
CGlobalDesc
*
p_c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
*
p_c_k_n_ho_wo_global_desc
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
...
@@ -490,7 +502,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -490,7 +502,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global
,
p_a_global
,
b_e_n_ho_wo_global_desc
,
b_e_n_ho_wo_global_desc
,
p_b_global
,
p_b_global
,
d_k_n_
2
ho
_2wo
_global_desc
,
d_k_n_ho
x2_wox2
_global_desc
,
p_d_global
,
p_d_global
,
c_k_n_ho_wo_global_desc
,
c_k_n_ho_wo_global_desc
,
p_c_global
,
p_c_global
,
...
@@ -504,7 +516,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -504,7 +516,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_a_global
,
const
void
*
p_b_e_n_ho_wo_global_desc
,
const
void
*
p_b_e_n_ho_wo_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
DGlobalDesc
&
d_k_n_
2
ho
_2wo
_global_desc
,
const
DGlobalDesc
&
d_k_n_ho
x2_wox2
_global_desc
,
const
FloatC
*
__restrict__
p_d_global
,
const
FloatC
*
__restrict__
p_d_global
,
const
void
*
p_c_k_n_ho_wo_global_desc
,
const
void
*
p_c_k_n_ho_wo_global_desc
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
...
@@ -521,14 +533,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -521,14 +533,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_a_global
,
p_a_global
,
b_e_n_ho_wo_global_desc
,
b_e_n_ho_wo_global_desc
,
p_b_global
,
p_b_global
,
d_k_n_
2
ho
_2wo
_global_desc
,
d_k_n_ho
x2_wox2
_global_desc
,
p_d_global
,
p_d_global
,
c_k_n_ho_wo_global_desc
,
c_k_n_ho_wo_global_desc
,
p_c_global
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
};
};
// namespace ck
}
// namespace ck
}
// namespace ck
#endif
#endif
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
03aa52bc
...
@@ -22,9 +22,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -22,9 +22,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
WeiDesc
,
WeiDesc
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
AddDesc
,
AddDesc
,
const
Tensor
<
TOut
>&
add_n_k_
2
ho
_2wo
,
const
Tensor
<
TOut
>&
add_n_k_ho
x2_wox2
,
OutDesc
,
OutDesc
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
Tensor
<
TOut
>&
out_n_k_ho
x2
_wo
x2
,
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
InLeftPads
,
InLeftPads
,
...
@@ -38,8 +38,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -38,8 +38,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
add_n_k_2ho_2wo_device_buf
(
sizeof
(
TOut
)
*
add_n_k_2ho_2wo
.
mDesc
.
GetElementSpace
());
DeviceMem
add_n_k_hox2_wox2_device_buf
(
sizeof
(
TOut
)
*
DeviceMem
out_n_k_ho_wo_device_buf
(
sizeof
(
TOut
)
*
add_n_k_2ho_2wo
.
mDesc
.
GetElementSpace
());
add_n_k_hox2_wox2
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_k_hox2_wox2_device_buf
(
sizeof
(
TOut
)
*
add_n_k_hox2_wox2
.
mDesc
.
GetElementSpace
());
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -56,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -56,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
auto
Ho
=
OutDesc
::
GetLengths
()[
I2
];
constexpr
auto
Ho
=
OutDesc
::
GetLengths
()[
I2
];
constexpr
auto
Wo
=
OutDesc
::
GetLengths
()[
I3
];
constexpr
auto
Wo
=
OutDesc
::
GetLengths
()[
I3
];
constexpr
auto
Hox2
=
Ho
*
2
;
constexpr
auto
Wox2
=
Wo
*
2
;
constexpr
auto
Y
=
WeiDesc
::
GetLengths
()[
I2
];
constexpr
auto
Y
=
WeiDesc
::
GetLengths
()[
I2
];
constexpr
auto
X
=
WeiDesc
::
GetLengths
()[
I3
];
constexpr
auto
X
=
WeiDesc
::
GetLengths
()[
I3
];
...
@@ -71,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -71,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc =
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
const auto out_n_k_ho_wo_desc =
const auto out_n_k_ho
x2
_wo
x2
_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_strides = to_multi_index(ConvStrides{});
...
@@ -86,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -86,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C0
,
Y
,
X
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C0
,
Y
,
X
));
const
auto
out_n_k0_ho_wo_k1_desc
=
const
auto
out_n_k0_ho_wo_k1_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
const
auto
add_n_k0_
2
ho
_2wo
_k1_desc
=
const
auto
add_n_k0_ho
x2_wox2
_k1_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
2
*
Ho
,
2
*
Wo
,
K1
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wo
x2
,
K1
));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
...
@@ -99,10 +104,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -99,10 +104,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
C0
,
Hi
,
Wi
,
C1
>
{})));
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
C0
,
Hi
,
Wi
,
C1
>
{})));
Tensor
<
TInWei
>
wei_k_c0_y_x_c1
(
make_HostTensorDescriptor
(
Tensor
<
TInWei
>
wei_k_c0_y_x_c1
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
K
,
C0
,
Y
,
X
,
C1
>
{})));
make_native_tensor_descriptor_packed
(
Sequence
<
K
,
C0
,
Y
,
X
,
C1
>
{})));
Tensor
<
TOut
>
add_n_k0_
2
ho
_2wo
_k1
(
make_HostTensorDescriptor
(
Tensor
<
TOut
>
add_n_k0_ho
x2_wox2
_k1
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
K0
,
2
*
Ho
,
2
*
Wo
,
K1
>
{})));
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
K0
,
Hox2
,
Wo
x2
,
K1
>
{})));
Tensor
<
TOut
>
out_n_k0_ho_wo_k1
(
make_HostTensorDescriptor
(
Tensor
<
TOut
>
out_n_k0_ho
x2
_wo
x2
_k1
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
K0
,
2
*
Ho
,
2
*
Wo
,
K1
>
{})));
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
K0
,
Hox2
,
Wo
x2
,
K1
>
{})));
auto
f_nchw2nc0hwc1
=
[
&
](
auto
n
,
auto
hi
,
auto
wi
,
auto
c
)
{
auto
f_nchw2nc0hwc1
=
[
&
](
auto
n
,
auto
hi
,
auto
wi
,
auto
c
)
{
in_n_c0_hi_wi_c1
(
n
,
c
/
InWeiVectorSize
,
hi
,
wi
,
c
%
InWeiVectorSize
)
=
in_n_c0_hi_wi_c1
(
n
,
c
/
InWeiVectorSize
,
hi
,
wi
,
c
%
InWeiVectorSize
)
=
...
@@ -115,17 +120,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -115,17 +120,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
};
};
auto
f_nkhw_to_nk0hwk1
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nkhw_to_nk0hwk1
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
add_n_k0_
2
ho
_2wo
_k1
(
n
,
k
/
InWeiVectorSize
,
ho
,
wo
,
k
%
InWeiVectorSize
)
=
add_n_k0_ho
x2_wox2
_k1
(
n
,
k
/
InWeiVectorSize
,
ho
,
wo
,
k
%
InWeiVectorSize
)
=
add_n_k_
2
ho
_2wo
(
n
,
k
,
ho
,
wo
);
add_n_k_ho
x2_wox2
(
n
,
k
,
ho
,
wo
);
};
};
make_ParallelTensorFunctor
(
f_nchw2nc0hwc1
,
N
,
Hi
,
Wi
,
C
)();
make_ParallelTensorFunctor
(
f_nchw2nc0hwc1
,
N
,
Hi
,
Wi
,
C
)();
make_ParallelTensorFunctor
(
f_kcyx2kc0yxc1
,
K
,
Y
,
X
,
C
)();
make_ParallelTensorFunctor
(
f_kcyx2kc0yxc1
,
K
,
Y
,
X
,
C
)();
make_ParallelTensorFunctor
(
f_nkhw_to_nk0hwk1
,
N
,
K
,
Ho
,
Wo
)();
make_ParallelTensorFunctor
(
f_nkhw_to_nk0hwk1
,
N
,
K
,
Ho
x2
,
Wo
x2
)();
in_n_c_hi_wi_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
in_n_c_hi_wi_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
add_n_k_
2
ho
_2wo
_device_buf
.
ToDevice
(
add_n_k0_
2
ho
_2wo
_k1
.
mData
.
data
());
add_n_k_ho
x2_wox2
_device_buf
.
ToDevice
(
add_n_k0_ho
x2_wox2
_k1
.
mData
.
data
());
#if 1
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
// cdata = 64, BlockSize = 64, 16x8x32x4
...
@@ -141,8 +146,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -141,8 +146,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
EPerBlock
;
constexpr
index_t
EPerThread
=
EPerBlock
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
3
,
1
>
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
9
,
1
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
3
*
EPerBlock
,
KPerBlock
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
EPerBlock
,
KPerBlock
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
...
@@ -205,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -205,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver
.
Run
(
wei_k_c0_y_x_desc
,
conv_driver
.
Run
(
wei_k_c0_y_x_desc
,
in_n_c0_hi_wi_desc
,
in_n_c0_hi_wi_desc
,
add_n_k0_
2
ho
_2wo
_k1_desc
,
add_n_k0_ho
x2_wox2
_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
conv_strides
,
conv_strides
,
conv_dilations
,
conv_dilations
,
...
@@ -215,18 +220,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -215,18 +220,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k_2ho_2wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()));
static_cast
<
TOut
*>
(
out_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()));
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k0_ho_wo_k1
.
mData
.
data
());
out_n_k_ho
x2
_wo
x2
_device_buf
.
FromDevice
(
out_n_k0_ho
x2
_wo
x2
_k1
.
mData
.
data
());
#if
0
#if
1
auto
f_nk0hwk1_to_nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nk0hwk1_to_nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
out_n_k_ho_wo(n, k, ho, wo) =
out_n_k_ho
x2
_wo
x2
(
n
,
k
,
ho
,
wo
)
=
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
out_n_k0_ho
x2
_wo
x2
_k1
(
n
,
k
/
InWeiVectorSize
,
ho
,
wo
,
k
%
InWeiVectorSize
);
};
};
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
make_ParallelTensorFunctor
(
f_nk0hwk1_to_nkhw
,
N
,
K
,
Ho
x2
,
Wo
x2
)();
#endif
#endif
}
}
driver/include/host_conv.hpp
View file @
03aa52bc
...
@@ -41,14 +41,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
...
@@ -41,14 +41,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
}
}
}
}
}
}
out_nkhw
(
n
,
k
,
ho
,
wo
)
=
v
+
add_nkhw
(
n
,
k
,
ho
,
wo
);
index_t
hox2
=
ho
*
2
;
index_t
wox2
=
wo
*
2
;
out_nkhw
(
n
,
k
,
hox2
,
wox2
)
=
v
+
add_nkhw
(
n
,
k
,
hox2
,
wox2
);
out_nkhw
(
n
,
k
,
hox2
,
wox2
+
1
)
=
v
+
add_nkhw
(
n
,
k
,
hox2
,
wox2
+
1
);
out_nkhw
(
n
,
k
,
hox2
+
1
,
wox2
)
=
v
+
add_nkhw
(
n
,
k
,
hox2
+
1
,
wox2
);
out_nkhw
(
n
,
k
,
hox2
+
1
,
wox2
+
1
)
=
v
+
add_nkhw
(
n
,
k
,
hox2
+
1
,
wox2
+
1
);
};
};
auto
f_par
=
make_ParallelTensorFunctor
(
f
,
auto
f_par
=
make_ParallelTensorFunctor
(
f
,
out_nkhw
.
mDesc
.
GetLengths
()[
0
],
out_nkhw
.
mDesc
.
GetLengths
()[
0
],
out_nkhw
.
mDesc
.
GetLengths
()[
1
],
out_nkhw
.
mDesc
.
GetLengths
()[
1
],
out_nkhw
.
mDesc
.
GetLengths
()[
2
],
out_nkhw
.
mDesc
.
GetLengths
()[
2
]
/
2
,
out_nkhw
.
mDesc
.
GetLengths
()[
3
]);
out_nkhw
.
mDesc
.
GetLengths
()[
3
]
/
2
);
f_par
(
std
::
thread
::
hardware_concurrency
());
f_par
(
std
::
thread
::
hardware_concurrency
());
}
}
...
...
driver/src/conv_driver.cpp
View file @
03aa52bc
...
@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
...
@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
...
@@ -700,7 +700,8 @@ int main(int argc, char* argv[])
...
@@ -700,7 +700,8 @@ int main(int argc, char* argv[])
};
};
wei_kcyx
.
GenerateTensorValue
(
gen_wei
,
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
gen_wei
,
num_thread
);
#endif
#endif
add_nkhw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
1
,
1
},
num_thread
);
// add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
add_nkhw
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
}
}
#if 0
#if 0
...
@@ -806,7 +807,7 @@ int main(int argc, char* argv[])
...
@@ -806,7 +807,7 @@ int main(int argc, char* argv[])
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#if
0
#if
1
if
(
do_log
)
if
(
do_log
)
{
{
LogRange
(
std
::
cout
<<
"in_nchw : "
,
in_nchw
.
mData
,
","
)
<<
std
::
endl
;
LogRange
(
std
::
cout
<<
"in_nchw : "
,
in_nchw
.
mData
,
","
)
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment