Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0c883faa
Commit
0c883faa
authored
Mar 31, 2021
by
Jing Zhang
Browse files
fixed outputpad
parent
351c227a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
58 deletions
+61
-58
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
...tion_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
+58
-54
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+3
-4
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
View file @
0c883faa
...
@@ -75,8 +75,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -75,8 +75,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
OutRightPadH
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
*
HoPerBlock
-
Ho
;
const
auto
Hop
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
*
HoPerBlock
;
const
auto
OutRightPadW
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
*
WoPerBlock
-
Wo
;
const
auto
Wop
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
*
WoPerBlock
;
const
auto
OutRightPadH
=
Hop
-
Ho
;
const
auto
OutRightPadW
=
Wop
-
Wo
;
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
...
@@ -111,8 +114,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -111,8 +114,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
p
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_embed_transform
(
make_tuple
(
X
,
Wo
p
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
@@ -120,13 +123,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -120,13 +123,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
in_n_c_y_ho_x_wo_global_desc
,
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
Ho
),
make_pass_through_transform
(
Ho
p
),
make_pass_through_transform
(
Wo
)),
make_pass_through_transform
(
Wo
p
)),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// output tensor
// output tensor
const
auto
out_k_n_ho_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_k_n_ho
p
_wo
p
_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
...
@@ -137,12 +140,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -137,12 +140,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
auto
E
=
C
*
Y
*
X
;
const
auto
E
=
C
*
Y
*
X
;
const
int
Ho_new
=
out_k_n_ho_wo_global_desc
.
GetLength
(
I2
);
std
::
cerr
<<
"Hop = "
<<
Hop
<<
" Wop = "
<<
Wop
<<
std
::
endl
;
const
int
Wo_new
=
out_k_n_ho_wo_global_desc
.
GetLength
(
I3
);
std
::
cerr
<<
"Ho_new = "
<<
Ho_new
<<
" Wo_new = "
<<
Wo_new
<<
std
::
endl
;
if
(
!
((
K
%
KPerBlock
)
==
0
&&
(
Ho
_new
%
HoPerBlock
)
==
0
&&
(
Wo
_new
%
WoPerBlock
)
==
0
&&
if
(
!
((
K
%
KPerBlock
)
==
0
&&
(
Ho
p
%
HoPerBlock
)
==
0
&&
(
Wo
p
%
WoPerBlock
)
==
0
&&
(
E
%
EPerBlock
)
==
0
))
(
E
%
EPerBlock
)
==
0
))
{
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
...
@@ -190,7 +190,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -190,7 +190,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho
p
_wo
p
_global_desc
),
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
...
@@ -221,7 +221,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -221,7 +221,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
)
*
N
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Ho
p
/
HoPerBlock
)
*
(
Wo
p
/
WoPerBlock
)
*
N
;
const
bool
has_main_k_block_loop
=
(
E
+
EPerBlock
)
/
(
2
*
EPerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
E
+
EPerBlock
)
/
(
2
*
EPerBlock
)
>
1
;
...
@@ -243,15 +243,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -243,15 +243,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
decltype
(
wei_e_k_global_desc
),
run_gridwise_operation
<
gridwise_gemm
,
const
FloatAB
*
,
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho_wo_global_desc
),
const
FloatAB
*
,
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
integral_constant
<
bool
,
true
>
,
FloatC
*
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
@@ -262,22 +263,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -262,22 +263,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
out_k_n_ho_wo_global_desc
,
out_k_n_ho
p
_wo
p
_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
integral_constant
<
bool
,
true
>
{});
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
decltype
(
wei_e_k_global_desc
),
run_gridwise_operation
<
gridwise_gemm
,
const
FloatAB
*
,
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho_wo_global_desc
),
const
FloatAB
*
,
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
integral_constant
<
bool
,
true
>
,
FloatC
*
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
@@ -288,22 +290,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -288,22 +290,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
out_k_n_ho_wo_global_desc
,
out_k_n_ho
p
_wo
p
_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
decltype
(
wei_e_k_global_desc
),
run_gridwise_operation
<
gridwise_gemm
,
const
FloatAB
*
,
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho_wo_global_desc
),
const
FloatAB
*
,
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
integral_constant
<
bool
,
false
>
,
FloatC
*
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
@@ -314,22 +317,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -314,22 +317,23 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
out_k_n_ho_wo_global_desc
,
out_k_n_ho
p
_wo
p
_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
decltype
(
wei_e_k_global_desc
),
run_gridwise_operation
<
gridwise_gemm
,
const
FloatAB
*
,
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_n_ho_wo_global_desc
),
const
FloatAB
*
,
const
FloatAB
*
,
decltype
(
in_e_n_ho_wo_global_desc
),
decltype
(
out_k_n_ho_wo_global_desc
),
const
FloatAB
*
,
FloatC
*
,
decltype
(
out_k_n_hop_wop_global_desc
),
integral_constant
<
bool
,
false
>
,
FloatC
*
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
@@ -340,7 +344,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -340,7 +344,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
p_in_global
,
p_in_global
,
out_k_n_ho_wo_global_desc
,
out_k_n_ho
p
_wo
p
_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
...
...
driver/src/conv_driver.cpp
View file @
0c883faa
...
@@ -36,11 +36,10 @@ int main(int argc, char* argv[])
...
@@ -36,11 +36,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
1
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
//constexpr index_t HI = 540;
constexpr
index_t
HI
=
540
;
constexpr
index_t
HI
=
544
;
constexpr
index_t
WI
=
960
;
constexpr
index_t
WI
=
960
;
constexpr
index_t
K
=
16
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
...
@@ -107,7 +106,7 @@ int main(int argc, char* argv[])
...
@@ -107,7 +106,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
0
#elif
1
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
constexpr
index_t
HI
=
540
;
constexpr
index_t
HI
=
540
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment