Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8e35a579
Commit
8e35a579
authored
Jan 06, 2021
by
Chao Liu
Browse files
refactor
parent
89123dd7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
31 deletions
+33
-31
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+33
-31
No files found.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
8e35a579
...
@@ -131,7 +131,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -131,7 +131,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index
(
KPerBlock
,
NPerBlock
),
max_lds_align
);
make_multi_index
(
KPerBlock
,
NPerBlock
),
max_lds_align
);
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_block_copy
=
auto
a_block
wise
_copy
=
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
Sequence
<
KPerBlock
,
MPerBlock
>
,
Sequence
<
KPerBlock
,
MPerBlock
>
,
...
@@ -160,7 +160,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -160,7 +160,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_multi_index
(
0
,
0
));
make_multi_index
(
0
,
0
));
// B matrix blockwise copy
// B matrix blockwise copy
auto
b_block_copy
=
auto
b_block
wise
_copy
=
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
Sequence
<
KPerBlock
,
NPerBlock
>
,
Sequence
<
KPerBlock
,
NPerBlock
>
,
...
@@ -219,7 +219,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -219,7 +219,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr
auto
c_m0m1_n0n1_thread_mtx_desc
=
make_ConstantMatrixDescriptor_packed
(
constexpr
auto
c_m0m1_n0n1_thread_mtx_desc
=
make_ConstantMatrixDescriptor_packed
(
Number
<
MRepeat
*
MPerThread
>
{},
Number
<
NRepeat
*
NPerThread
>
{});
Number
<
MRepeat
*
MPerThread
>
{},
Number
<
NRepeat
*
NPerThread
>
{});
const
auto
block_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
const
auto
block
wise
_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
BlockSize
,
BlockSize
,
decltype
(
a_k_m_block_mtx_desc
),
decltype
(
a_k_m_block_mtx_desc
),
decltype
(
b_k_n_block_mtx_desc
),
decltype
(
b_k_n_block_mtx_desc
),
...
@@ -256,14 +256,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -256,14 +256,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
#if 1
#if 1
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
a_block
wise
_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
b_block
wise
_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
);
a_block
wise
_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
);
b_block
wise
_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
);
}
}
#endif
#endif
#if 1
Float
*
p_a_block_even
=
p_a_block_double
;
Float
*
p_a_block_even
=
p_a_block_double
;
Float
*
p_b_block_even
=
p_b_block_double
;
Float
*
p_b_block_even
=
p_b_block_double
;
...
@@ -275,65 +276,66 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -275,65 +276,66 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
k_block_data_begin
+=
2
*
KPerBlock
)
k_block_data_begin
+=
2
*
KPerBlock
)
{
{
// even iteration
// even iteration
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
a_block
wise
_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
b_block
wise
_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
block_gemm
.
Run
(
p_a_block_even
,
p_b_block_even
,
p_c_thread
);
block
wise
_gemm
.
Run
(
p_a_block_even
,
p_b_block_even
,
p_c_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_odd
);
a_block
wise
_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_odd
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_odd
);
b_block
wise
_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_odd
);
// odd iteration
// odd iteration
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
a_block
wise
_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
b_block
wise
_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
block_gemm
.
Run
(
p_a_block_odd
,
p_b_block_odd
,
p_c_thread
);
block
wise
_gemm
.
Run
(
p_a_block_odd
,
p_b_block_odd
,
p_c_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_even
);
a_block
wise
_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_even
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_even
);
b_block
wise
_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_even
);
}
}
#endif
#if 1
#if 1
// LDS double buffer: tail
// LDS double buffer: tail
{
{
if
constexpr
(
IsEvenNumberKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
IsEvenNumberKBlockLoop
)
// if has 2 iteration left
{
{
a_block_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_block_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
a_block
wise
_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_block_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
b_block
wise
_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
block_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
block
wise
_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
// LDS double buffer: store last data to LDS
// LDS double buffer: store last data to LDS
a_block_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
+
a_block_space_size
);
a_block
wise
_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
+
a_block_space_size
);
b_block_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
+
b_block_space_size
);
b_block
wise
_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
+
b_block_space_size
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
block_gemm
.
Run
(
p_a_block_double
+
a_block_space_size
,
block
wise
_gemm
.
Run
(
p_a_block_double
+
a_block_space_size
,
p_b_block_double
+
b_block_space_size
,
p_b_block_double
+
b_block_space_size
,
p_c_thread
);
p_c_thread
);
}
}
...
@@ -342,7 +344,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -342,7 +344,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
block_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
block
wise
_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
}
}
}
}
#endif
#endif
...
@@ -361,7 +363,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -361,7 +363,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// calculate origin of thread input tensor on global memory
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
const
auto
c_thread_mtx_on_block
=
block_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
block
wise
_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
index_t
m_thread_data_on_global
=
const
index_t
m_thread_data_on_global
=
m_block_data_on_global
+
c_thread_mtx_on_block
.
row
;
m_block_data_on_global
+
c_thread_mtx_on_block
.
row
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment