Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cd29b09a
Commit
cd29b09a
authored
May 19, 2019
by
Chao Liu
Browse files
refactor
parent
a6b95c39
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
51 deletions
+49
-51
driver/driver.hip.cpp
driver/driver.hip.cpp
+2
-2
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
...plicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
+47
-49
No files found.
driver/driver.hip.cpp
View file @
cd29b09a
...
@@ -608,11 +608,11 @@ int main(int argc, char* argv[])
...
@@ -608,11 +608,11 @@ int main(int argc, char* argv[])
device_convolution_direct_v2_nchw_kcyx_nkhw
device_convolution_direct_v2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif
1
#elif
0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif
0
#elif
1
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
View file @
cd29b09a
...
@@ -100,8 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -100,8 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// global tensor view
// global tensor view
constexpr
auto
wei_c_k_global_desc
=
constexpr
auto
wei_c_k_global_desc
=
wei_c_y_x_k_global_desc
.
Extract
(
I0
,
I3
);
make_ConstantTensorDescriptor
(
Sequence
<
C
,
K
>
{},
Sequence
<
Y
*
X
*
K
,
1
>
{});
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
...
@@ -359,13 +358,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -359,13 +358,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
%
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
%
NPerBlock
;
static_if
<
GemmNPerThreadSubC
<=
NPerBlock
>
{}([
&
](
auto
f_dummy
)
{
// f_dummy do nothing but
static_if
<
GemmNPerThreadSubC
<=
NPerBlock
>
{}([
&
](
auto
fwd
)
{
// perfect forwarding.
// fwd do nothing but perfect forwarding.
// Using this trick to
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
// make this lambda a generic lambda, so it won't be compiled until
// begin instantiated here
// instantiated
static_assert
(
static_assert
(
(
f
_dummy
(
GemmNPerThreadSubC
)
<=
NPerBlock
&&
NPerBlock
%
GemmNPerThreadSubC
==
0
),
(
f
wd
(
GemmNPerThreadSubC
)
<=
NPerBlock
&&
NPerBlock
%
GemmNPerThreadSubC
==
0
),
"wrong!"
);
"wrong!"
);
// output is a 10d tensor
// output is a 10d tensor
...
@@ -373,38 +371,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -373,38 +371,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
constexpr
index_t
N1
=
NPerBlock
/
N2
;
constexpr
index_t
N1
=
NPerBlock
/
N2
;
constexpr
index_t
W2
=
constexpr
index_t
W2
=
(
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
/
f
_dummy
(
NPerBlock
/
GemmNPerThreadSubC
);
(
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
/
f
wd
(
NPerBlock
/
GemmNPerThreadSubC
);
constexpr
index_t
W1
=
WoPerBlock
/
W2
;
constexpr
index_t
W1
=
WoPerBlock
/
W2
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
constexpr
auto
out_10d_global_desc
=
fwd
(
out_n_k_h_w_global_desc
)
make_ConstantTensorDescriptor
(
Sequence
<
N
/
f_dummy
(
N1
*
N2
),
.
Fold
(
I3
,
Number
<
W1
>
{},
Number
<
W2
>
{})
N1
,
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
N2
,
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
K
/
(
K1
*
K2
),
K1
,
constexpr
auto
out_10d_thread_desc
=
fwd
(
out_k_h_w_n_thread_desc
)
K2
,
.
Fold
(
I3
,
Number
<
1
>
{},
Number
<
N2
>
{})
Ho
,
.
Fold
(
I2
,
Number
<
W1
>
{},
Number
<
1
>
{})
Wo
/
(
W1
*
W2
),
.
Fold
(
I0
,
Number
<
1
>
{},
Number
<
K2
>
{});
W1
,
W2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
#if 0
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"
out_k_h_w_n_thread_desc");
"a:
out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "
a:
out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_
n_
global_desc,
print_ConstantTensorDescriptor(out_
n_
k_h_w_global_desc,
"
out_k_h_w_
n_
global_desc");
"a:
out_
n_
k_h_w_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "
a:
out_10d_global_desc");
}
}
#endif
#endif
constexpr
auto
map_out_global2thread
=
Sequence
<
7
,
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
>
{};
constexpr
auto
map_out_global2thread
=
Sequence
<
7
,
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
>
{};
...
@@ -421,8 +414,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -421,8 +414,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
out_10d_thread_desc
.
GetLengths
(),
out_10d_thread_desc
.
GetLengths
(),
map_out_global2thread
);
map_out_global2thread
);
// Number<OutThreadCopyDataPerWrite_W>{});
// Number<OutThreadCopyDataPerWrite_W>{});
}).
else_
([
&
](
auto
f
_dummy
)
{
}).
else_
([
&
](
auto
f
wd
)
{
static_assert
(
f
_dummy
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
static_assert
(
f
wd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
GemmNPerThreadSubC
%
NPerThread
==
0
,
GemmNPerThreadSubC
%
NPerThread
==
0
,
"wrong!"
);
"wrong!"
);
...
@@ -431,29 +424,34 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -431,29 +424,34 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
constexpr
index_t
W3
=
GemmNPerThreadSubC
/
NPerBlock
;
constexpr
index_t
W3
=
GemmNPerThreadSubC
/
NPerBlock
;
constexpr
index_t
W2
=
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
W2
=
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
W1
=
WoPerBlock
/
f
_dummy
(
W2
*
W3
);
constexpr
index_t
W1
=
WoPerBlock
/
f
wd
(
W2
*
W3
);
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_10d_global_desc
=
Sequence
<
N
/
N1
,
N1
,
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
*
W3
),
W1
,
W2
,
W3
>
{});
fwd
(
out_n_k_h_w_global_desc
)
.
Fold
(
I3
,
Number
<
W1
>
{},
Number
<
W2
>
{},
Number
<
W3
>
{})
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_10d_thread_desc
=
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
W3
,
1
,
N1
>
{});
fwd
(
out_k_h_w_n_thread_desc
)
.
Fold
(
I3
,
Number
<
N1
>
{})
.
Fold
(
I2
,
Number
<
W1
>
{},
Number
<
1
>
{},
Number
<
W3
>
{})
.
Fold
(
I0
,
Number
<
1
>
{},
Number
<
K2
>
{});
#if 0
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
"b: out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
"b: out_n_k_h_w_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
}
#endif
#endif
constexpr
auto
map_out_global2thread
=
Sequence
<
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
{};
constexpr
auto
map_out_global2thread
=
Sequence
<
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment