Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
50b96745
Commit
50b96745
authored
Feb 17, 2019
by
Chao Liu
Browse files
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn use khwn for thread C data now
parent
1cb98850
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
22 deletions
+21
-22
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
+1
-1
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
+20
-21
No files found.
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp
View file @
50b96745
...
...
@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif
1
#elif
0
// for 1x1, 28x28
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
128
;
...
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp
View file @
50b96745
...
...
@@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Sequence
<
CPerBlock
,
S
,
R
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_
h
kwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
Ho
PerThread
,
K
PerThread
,
WoPerThread
,
NPerThread
>
{});
constexpr
auto
out_k
h
wn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
PerThread
,
Ho
PerThread
,
WoPerThread
,
NPerThread
>
{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...
...
@@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Number
<
in_chwn_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{});
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_khwn_thread_desc
.
GetStride
(
I1
)
>
{});
#if 0
const auto blockwise_batch_gemm =
...
...
@@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
false,
0,
in_chwn_block_desc.GetStride(I1),
out_
h
kwn_thread_desc.GetStride(I
0
),
out_k
h
wn_thread_desc.GetStride(I
1
),
HoPerBlock,
HoPerThread,
GemmKPerThreadLoop,
...
...
@@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
decltype
(
c_kxwn_thread_mtx_desc
),
0
,
in_chwn_block_desc
.
GetStride
(
I1
),
out_
h
kwn_thread_desc
.
GetStride
(
I
0
),
out_k
h
wn_thread_desc
.
GetStride
(
I
1
),
HoPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
...
...
@@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
__shared__
Float
p_wei_block
[
max_align
*
((
wei_block_size
+
max_align
-
1
)
/
max_align
)];
// register
Float
p_out_thread
[
out_
h
kwn_thread_desc
.
GetElementSpace
()];
Float
p_out_thread
[
out_k
h
wn_thread_desc
.
GetElementSpace
()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_
h
kwn_thread_desc
,
p_out_thread
);
threadwise_4d_tensor_set_zero
(
out_k
h
wn_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_begin
=
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
...
...
@@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
#if 0
// for v1 batch-gemm
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
threadwise_4d_tensor_copy(
out_khwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
out_khwn_thread_desc.GetLengths());
#else
for
(
unsigned
ho
=
0
;
ho
<
out_
h
kwn_thread_desc
.
GetLength
(
I0
);
++
ho
)
for
(
unsigned
k
=
0
;
k
<
out_k
h
wn_thread_desc
.
GetLength
(
I0
);
++
k
)
{
for
(
unsigned
k
=
0
;
k
<
out_
h
kwn_thread_desc
.
GetLength
(
I1
);
++
k
)
for
(
unsigned
ho
=
0
;
ho
<
out_k
h
wn_thread_desc
.
GetLength
(
I1
);
++
ho
)
{
for
(
unsigned
wo
=
0
;
wo
<
out_
h
kwn_thread_desc
.
GetLength
(
I2
);
++
wo
)
for
(
unsigned
wo
=
0
;
wo
<
out_k
h
wn_thread_desc
.
GetLength
(
I2
);
++
wo
)
{
for
(
unsigned
n
=
0
;
n
<
out_
h
kwn_thread_desc
.
GetLength
(
I3
);
++
n
)
for
(
unsigned
n
=
0
;
n
<
out_k
h
wn_thread_desc
.
GetLength
(
I3
);
++
n
)
{
const
unsigned
b
=
out_
h
kwn_thread_desc
.
Get1dIndex
(
0
,
0
,
wo
,
n
);
const
unsigned
b
=
out_k
h
wn_thread_desc
.
Get1dIndex
(
0
,
0
,
wo
,
n
);
const
auto
c_thread_mtx_distance
=
blockwise_batch_gemm
.
GetDistanceFromBeginOfThreadMatrixC
(
ho
,
k
,
b
);
...
...
@@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
const
unsigned
b_thread
=
c_thread_mtx_begin
.
col
+
c_thread_mtx_distance
.
col
;
const
unsigned
wo_thread
=
b_thread
/
NPerBlock
;
const
unsigned
n_thread
=
b_thread
-
NPerBlock
*
wo_thread
;
const
unsigned
n_thread
=
b_thread
%
NPerBlock
;
p_out_global
[
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread
,
ho_block_data_begin
+
ho_thread
,
wo_block_data_begin
+
wo_thread
,
n_block_data_begin
+
n_thread
)]
=
p_out_thread
[
out_
h
kwn_thread_desc
.
Get1dIndex
(
ho
,
k
,
wo
,
n
)];
p_out_thread
[
out_k
h
wn_thread_desc
.
Get1dIndex
(
k
,
ho
,
wo
,
n
)];
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment