gaoqiong / composable_kernel - Commits

Commit 91ef99a7, authored Mar 14, 2021 by root
Parent: 88d51698

Commit message: double buffer b with bug

Showing 3 changed files with 54 additions and 20 deletions (+54, -20):
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                          +1  -1
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp                   +49 -15
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp   +4  -4
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
@@ -147,7 +147,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         // loop over k
         for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
         {
-#if 1
+#if 0
             a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
                                   mMyThreadOffsetA,
                               p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, 0)));
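
For context, the loop in this hunk walks the reduction (C*Y*X) dimension of the block tile in steps of CYXPerThreadLoop, copying a per-thread slice of A out of LDS before each threadwise GEMM; the commit simply compiles that copy out by flipping "#if 1" to "#if 0". Below is a minimal host-side sketch of that loop structure only, with hypothetical helpers copy_a_slice and threadwise_gemm standing in for a_thread_copy.Run and the real threadwise GEMM (not the actual composable_kernel API):

#include <cstdio>

constexpr int CYXPerBlock      = 8; // reduction length of the block tile (illustrative value)
constexpr int CYXPerThreadLoop = 2; // reduction length handled per inner iteration

// Hypothetical stand-in for a_thread_copy.Run: copy one per-thread slice of A.
static void copy_a_slice(const float* a_block, int cyx_begin, float* a_thread)
{
    for(int i = 0; i < CYXPerThreadLoop; ++i)
        a_thread[i] = a_block[cyx_begin + i];
}

// Hypothetical stand-in for the threadwise GEMM: just accumulate the slice.
static void threadwise_gemm(const float* a_thread, float& c)
{
    for(int i = 0; i < CYXPerThreadLoop; ++i)
        c += a_thread[i];
}

int main()
{
    float a_block[CYXPerBlock] = {1, 2, 3, 4, 5, 6, 7, 8}; // plays the role of the LDS tile
    float a_thread[CYXPerThreadLoop];
    float c = 0.0f;

    // loop over the reduction dimension, CYXPerThreadLoop elements at a time
    for(int cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
    {
        copy_a_slice(a_block, cyx_begin, a_thread); // the copy the commit guards with #if 0
        threadwise_gemm(a_thread, c);
    }

    std::printf("c = %f (expect 36)\n", c);
    return 0;
}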
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
@@ -219,9 +219,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
 
-        constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
-
         Float* p_a_block_double = p_shared_block;
 
         // register allocation for output
@@ -235,8 +232,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // zero out threadwise output
         // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
 
         constexpr auto a_block_slice_copy_step = make_multi_index(CYXPerBlock, 0);
-        // constexpr auto b_block_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
 
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
@@ -249,8 +246,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
 
-        Float p_b_thread[b_cyx_n_h_w_thread_desc.GetElementSpaceSize()];
+        constexpr auto b_thread_space_size = b_cyx_n_h_w_thread_desc.GetElementSpaceSize();
+        Float p_b_thread[b_thread_space_size * 2];
+        Float* p_b_thread_double = p_b_thread;
 
         // LDS double buffer: preload data into LDS
         {
@@ -260,27 +259,32 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                       p_b_global,
                                       b_cyx_n_h_w_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
-                                      p_b_thread,
+                                      p_b_thread_double,
                                       b_cyx_n_h_w_global_iterator_hacks);
 
             a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
 
+#if 0
             __syncthreads();
 
-            //blockwise_gemm.Run(p_a_block_double, p_b_thread, p_c_thread);
+            //blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
 
             index_t sum = 0;
 
             for(index_t i = 0; i < b_cyx_n_h_w_thread_desc.GetElementSpaceSize(); i++)
                 sum += p_b_thread[i];
 
             p_c_thread[0] = get_thread_local_1d_id() * 10000 + sum;
+#endif
         }
 
-#if 0
+#if 1
         if constexpr(HasMainKBlockLoop)
         {
             Float* p_a_block_even = p_a_block_double;
             Float* p_a_block_odd  = p_a_block_double + a_block_space_size;
 
+            Float* p_b_thread_even = p_b_thread_double;
+            Float* p_b_thread_odd  = p_b_thread_double + b_thread_space_size;
+
             index_t b_block_data_begin = 0;
@@ -293,14 +297,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
 
+                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                                                         b_thread_slice_copy_step);
+
                 __syncthreads();
 
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
                     a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
 
+                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                                          p_b_global,
+                                          b_cyx_n_h_w_thread_desc,
+                                          make_tuple(I0, I0, I0, I0),
+                                          p_b_thread_odd,
+                                          b_cyx_n_h_w_global_iterator_hacks);
+
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+                blockwise_gemm.Run(p_a_block_even, p_b_thread_even, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd);
@@ -309,14 +323,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
+                b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                                                         b_thread_slice_copy_step);
+
                 __syncthreads();
 
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
                     a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
 
+                b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                                          p_b_global,
+                                          b_cyx_n_h_w_thread_desc,
+                                          make_tuple(I0, I0, I0, I0),
+                                          p_b_thread_even,
+                                          b_cyx_n_h_w_global_iterator_hacks);
+
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+                blockwise_gemm.Run(p_a_block_odd, p_b_thread_odd, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even);
@@ -332,13 +356,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                                 a_block_slice_copy_step,
                                                 a_k_m_global_move_slice_window_iterator_hack);
+            b_threadwise_transfer.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                                                     b_thread_slice_copy_step);
+
             __syncthreads();
 
             // LDS double buffer: load last data from device mem
             a_blockwise_copy.RunRead(
                 a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
 
+            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                                      p_b_global,
+                                      b_cyx_n_h_w_thread_desc,
+                                      make_tuple(I0, I0, I0, I0),
+                                      p_b_thread_double + b_thread_space_size,
+                                      b_cyx_n_h_w_global_iterator_hacks);
+
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
 
             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size);
@@ -347,7 +381,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
 
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
+                               p_b_thread_double + b_thread_space_size,
                                p_c_thread);
         }
         else // if has 1 iteration left
@@ -355,7 +389,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(p_a_block_double, p_b_thread_double, p_c_thread);
         }
 
#endif
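
Taken together, the gridwise changes above implement the usual ping-pong schedule for the B tile: p_b_thread is now sized for two tiles, iteration 0 is preloaded into the first half, and the main loop fetches the next tile into one half while the GEMM consumes the other, finishing with a tail GEMM on the last prefetched tile. The following is a minimal host-side sketch of that schedule only, with hypothetical helpers load_b_tile and gemm_tile standing in for b_threadwise_transfer.Run and blockwise_gemm.Run, and with the explicit even/odd unrolling of the real kernel folded into a parity check; it illustrates the idea, not the kernel itself.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for b_threadwise_transfer.Run: fetch one B tile.
static void load_b_tile(int iter, float* dst, int tile_size)
{
    for(int i = 0; i < tile_size; ++i)
        dst[i] = static_cast<float>(iter); // pretend each tile holds its iteration index
}

// Hypothetical stand-in for blockwise_gemm.Run: consume one B tile.
static void gemm_tile(const float* b_tile, int tile_size, float& c_acc)
{
    for(int i = 0; i < tile_size; ++i)
        c_acc += b_tile[i];
}

int main()
{
    constexpr int tile_size = 4; // plays the role of b_thread_space_size
    constexpr int num_iters = 8; // plays the role of the number of CYX block steps

    std::vector<float> p_b_thread(tile_size * 2); // two tiles back-to-back, as in the commit
    float* p_b_even = p_b_thread.data();
    float* p_b_odd  = p_b_thread.data() + tile_size;

    float c = 0.0f;

    // preload iteration 0 into the even half
    load_b_tile(0, p_b_even, tile_size);

    // main loop: fetch the next tile into one half while computing on the other
    for(int iter = 0; iter + 1 < num_iters; ++iter)
    {
        float* compute_buf = (iter % 2 == 0) ? p_b_even : p_b_odd;
        float* load_buf    = (iter % 2 == 0) ? p_b_odd : p_b_even;

        load_b_tile(iter + 1, load_buf, tile_size); // prefetch next B tile
        gemm_tile(compute_buf, tile_size, c);       // GEMM on current B tile
    }

    // tail: GEMM on the last prefetched tile
    gemm_tile((num_iters % 2 == 0) ? p_b_odd : p_b_even, tile_size, c);

    std::printf("accumulated %f (expect %f)\n", c, tile_size * (num_iters - 1) * num_iters / 2.0f);
    return 0;
}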
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
@@ -73,20 +73,20 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t KPerBlock   = 16;
     constexpr index_t HPerBlock   = 8;
     constexpr index_t WPerBlock   = 8;
-    constexpr index_t CYXPerBlock = 4 * 3 * 3;
+    constexpr index_t CYXPerBlock = 4;
 
     constexpr index_t KPerThread   = 16;
     constexpr index_t HPerThread   = 1;
     constexpr index_t WPerThread   = 1;
-    constexpr index_t CYXPerThread = 4 * 3 * 3;
+    constexpr index_t CYXPerThread = 4;
 
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<9, 1>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
 
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
 
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<36, 1>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
     using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
 
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
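
The driver change shrinks the per-block and per-thread reduction length from 4 * 3 * 3 to 4 and shrinks the transfer thread-slice lengths to match. Under the usual composable_kernel convention that thread slice lengths times thread cluster lengths cover the block tile in each dimension, the new numbers remain consistent. The check below is a compile-time sketch of that bookkeeping, where treating GemmN as the HPerBlock * WPerBlock tile (64) is an assumption about how the N dimension is formed in this driver:

#include <cstddef>

// Values from the new version of the driver file.
constexpr std::size_t KPerBlock   = 16;
constexpr std::size_t HPerBlock   = 8;
constexpr std::size_t WPerBlock   = 8;
constexpr std::size_t CYXPerBlock = 4;

// A-transfer: ThreadSliceLengths = {1, 1}, ThreadClusterLengths = {4, 16} over (GemmK, GemmM)
constexpr std::size_t ASlice[2]   = {1, 1};
constexpr std::size_t ACluster[2] = {4, 16};

// B-transfer: ThreadSliceLengths = {4, 1}, ThreadClusterLengths = {1, 64} over (GemmK, GemmN)
constexpr std::size_t BSlice[2]   = {4, 1};
constexpr std::size_t BCluster[2] = {1, 64};

// Convention assumed here: slice * cluster covers the block tile in each dimension,
// and the product of the cluster lengths gives the number of threads per block.
static_assert(ASlice[0] * ACluster[0] == CYXPerBlock, "A tile: GemmK coverage");
static_assert(ASlice[1] * ACluster[1] == KPerBlock, "A tile: GemmM coverage");
static_assert(BSlice[0] * BCluster[0] == CYXPerBlock, "B tile: GemmK coverage");
// GemmN is assumed to be the H*W tile here (HPerBlock * WPerBlock).
static_assert(BSlice[1] * BCluster[1] == HPerBlock * WPerBlock, "B tile: GemmN coverage");
static_assert(ACluster[0] * ACluster[1] == BCluster[0] * BCluster[1],
              "both transfers assume the same block size");

int main() { return 0; } // nothing to run; all checks are at compile time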