Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5b242405
Commit
5b242405
authored
Mar 18, 2021
by
Chao Liu
Browse files
refactor
parent
f1403dac
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
304 additions
and
253 deletions
+304
-253
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+4
-4
composable_kernel/include/gridwise_operation_wrapper.hpp
composable_kernel/include/gridwise_operation_wrapper.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+229
-222
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+18
-7
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+33
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+14
-14
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+3
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
5b242405
...
...
@@ -17,8 +17,8 @@ template <index_t BlockSize,
index_t
WoPerBlock
,
index_t
EPerBlock
,
index_t
KPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
H
o
PerThread
,
index_t
W
o
PerThread
,
index_t
EPerThread
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
...
...
@@ -178,8 +178,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
WoPerBlock
,
EPerBlock
,
KPerThread
,
HPerThread
,
WPerThread
,
H
o
PerThread
,
W
o
PerThread
,
EPerThread
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
...
...
composable_kernel/include/gridwise_operation_wrapper.hpp
View file @
5b242405
...
...
@@ -3,10 +3,10 @@
template
<
typename
GridwiseOp
,
typename
...
Xs
>
__global__
void
#if
0
__launch_bounds__(
256, 2
)
#if
CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
run_gridwise_operation
(
Xs
...
xs
)
run_gridwise_operation
(
Xs
...
xs
)
{
GridwiseOp
{}.
Run
(
xs
...);
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
5b242405
...
...
@@ -19,12 +19,12 @@ template <index_t BlockSize,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
index_t
KPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
H
o
PerBlock
,
index_t
W
o
PerBlock
,
index_t
EPerBlock
,
index_t
KPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
H
o
PerThread
,
index_t
W
o
PerThread
,
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
...
...
@@ -69,9 +69,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_h_w_global_desc
,
const
BGlobalDesc
&
b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_h_w_global_desc
,
const
CGlobalDesc
&
c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
...
...
@@ -85,35 +85,35 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const
auto
E
=
a_e_k_global_desc
.
GetLength
(
I0
);
const
auto
K
=
a_e_k_global_desc
.
GetLength
(
I1
);
const
auto
N
=
b_e_n_h_w_global_desc
.
GetLength
(
I1
);
const
auto
H
=
b_e_n_h_w_global_desc
.
GetLength
(
I2
);
const
auto
W
=
b_e_n_h_w_global_desc
.
GetLength
(
I3
);
const
auto
N
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I1
);
const
auto
H
o
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I2
);
const
auto
W
o
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I3
);
// divide block work by [M, N]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto hw_block_work_num = h_block_work_num * w_block_work_num;
const auto h
o
_block_work_num = H
o
/ Number<H
o
PerBlock>{};
const auto w
o
_block_work_num = W
o
/ Number<W
o
PerBlock>{};
const auto hw
o
_block_work_num = h
o
_block_work_num * w
o
_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num;
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw
o
_block_work_num;
const index_t hw
o
_block_work_id = get_block_1d_id() - k_block_work_id * hw
o
_block_work_num;
const index_t h
o
_block_work_id = hw
o
_block_work_id / w
o
_block_work_num;
const index_t w
o
_block_work_id = hw
o
_block_work_id - h
o
_block_work_id * w
o
_block_work_num;
#else
// Hack: this force result into SGPR
const
index_t
k_block_work_num
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
const
index_t
h_block_work_num
=
__builtin_amdgcn_readfirstlane
(
H
/
HPerBlock
);
const
index_t
w_block_work_num
=
__builtin_amdgcn_readfirstlane
(
W
/
WPerBlock
);
const
index_t
hw_block_work_num
=
h_block_work_num
*
w_block_work_num
;
const
index_t
k_block_work_num
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
const
index_t
h
o
_block_work_num
=
__builtin_amdgcn_readfirstlane
(
H
o
/
H
o
PerBlock
);
const
index_t
w
o
_block_work_num
=
__builtin_amdgcn_readfirstlane
(
W
o
/
W
o
PerBlock
);
const
index_t
hw
o
_block_work_num
=
h
o
_block_work_num
*
w
o
_block_work_num
;
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
hw_block_work_num
);
const
index_t
hw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw_block_work_num
;
const
index_t
h_block_work_id
=
__builtin_amdgcn_readfirstlane
(
hw_block_work_id
/
w_block_work_num
);
const
index_t
w_block_work_id
=
hw_block_work_id
-
h_block_work_id
*
w_block_work_num
;
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
hw
o
_block_work_num
);
const
index_t
hw
o
_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw
o
_block_work_num
;
const
index_t
h
o
_block_work_id
=
__builtin_amdgcn_readfirstlane
(
hw
o
_block_work_id
/
w
o
_block_work_num
);
const
index_t
w
o
_block_work_id
=
hw
o
_block_work_id
-
h
o
_block_work_id
*
w
o
_block_work_num
;
#endif
// lds max alignment
...
...
@@ -127,39 +127,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_e_n_h_w_block_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HPerBlock
>
{},
Number
<
WPerBlock
>
{}));
constexpr
auto
b_e_n_ho_wo_block_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{}));
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr
auto
c_k_n_h_w_thread_desc
=
constexpr
auto
c_k_n_h
o
_w
o
_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
decltype
(
a_e_k_block_desc
),
decltype
(
b_e_n_h_w_block_desc
),
decltype
(
c_k_n_h_w_thread_desc
),
KPerThread
,
// KPerThreadSubC
HPerThread
,
// HPerThreadSubC
WPerThread
,
// WPerThreadSubC
EPerThread
,
// EPerThreadLoop
1
,
// ThreadGemmADataPerRead_K
1
// ThreadGemmBDataPerRead_W
>
{};
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
decltype
(
a_e_k_block_desc
),
decltype
(
b_e_n_ho_wo_block_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
KPerThread
,
// KPerThreadSubC
HoPerThread
,
// HoPerThreadSubC
WoPerThread
,
// WoPerThreadSubC
EPerThread
,
// EPerThreadLoop
1
,
// ThreadGemmADataPerRead_K
1
// ThreadGemmBDataPerRead_W
>
{};
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
h_thread_id
=
c_thread_mtx_index
.
h
;
const
auto
w_thread_id
=
c_thread_mtx_index
.
w
;
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
h
o
_thread_id
=
c_thread_mtx_index
.
h
;
const
auto
w
o
_thread_id
=
c_thread_mtx_index
.
w
;
const
index_t
k_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
h_block_data_on_global
=
h_block_work_id
*
HPerBlock
;
const
index_t
w_block_data_on_global
=
w_block_work_id
*
WPerBlock
;
const
index_t
k_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
h
o
_block_data_on_global
=
h
o
_block_work_id
*
H
o
PerBlock
;
const
index_t
w
o
_block_data_on_global
=
w
o
_block_work_id
*
W
o
PerBlock
;
const
index_t
h_thread_data_on_global
=
h_block_data_on_global
+
h_thread_id
*
HPerThread
;
const
index_t
w_thread_data_on_global
=
w_block_data_on_global
+
w_thread_id
*
WPerThread
;
const
index_t
ho_thread_data_on_global
=
ho_block_data_on_global
+
ho_thread_id
*
HoPerThread
;
const
index_t
wo_thread_data_on_global
=
wo_block_data_on_global
+
wo_thread_id
*
WoPerThread
;
// A matrix blockwise copy
auto
a_blockwise_copy
=
...
...
@@ -190,26 +194,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_e_k_block_desc
,
make_multi_index
(
0
,
0
));
constexpr
auto
b_e_n_h_w_thread_desc
=
constexpr
auto
b_e_n_h
o
_w
o
_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseDynamicTensorSliceTransfer_v2
<
Float
,
Float
,
decltype
(
b_e_n_h_w_global_desc
),
decltype
(
b_e_n_h_w_thread_desc
),
Sequence
<
EPerBlock
,
1
,
HPerThread
,
WPerThread
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
AddressSpace
::
Global
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
,
1
,
true
>
(
b_e_n_h_w_global_desc
,
make_multi_index
(
0
,
0
,
h_thread_data_on_global
,
w_thread_data_on_global
));
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseDynamicTensorSliceTransfer_v2
<
Float
,
Float
,
decltype
(
b_e_n_ho_wo_global_desc
),
decltype
(
b_e_n_ho_wo_thread_desc
),
Sequence
<
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
AddressSpace
::
Global
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
,
1
,
true
>
(
b_e_n_ho_wo_global_desc
,
make_multi_index
(
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
...
...
@@ -218,26 +221,26 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float
*
p_a_block_double
=
p_shared_block
;
// register allocation for output
AccFloat
p_c_thread
[
c_k_n_h_w_thread_desc
.
GetElementSpaceSize
()];
AccFloat
p_c_thread
[
c_k_n_h
o
_w
o
_thread_desc
.
GetElementSpaceSize
()];
// zero out threadwise output
threadwise_matrix_set_zero_v3
(
c_k_n_h_w_thread_desc
,
p_c_thread
);
threadwise_matrix_set_zero_v3
(
c_k_n_h
o
_w
o
_thread_desc
,
p_c_thread
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k_m_global_iterator_hacks
=
AGlobalIteratorHacks
{};
constexpr
auto
b_e_n_h_w_global_iterator_hacks
=
BGlobalIteratorHacks
{};
constexpr
auto
a_k_m_global_iterator_hacks
=
AGlobalIteratorHacks
{};
constexpr
auto
b_e_n_h
o
_w
o
_global_iterator_hacks
=
BGlobalIteratorHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
AGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_e_n_h_w_global_move_slice_window_iterator_hack
=
constexpr
auto
b_e_n_h
o
_w
o
_global_move_slice_window_iterator_hack
=
BGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_thread_space_size
=
b_e_n_h_w_thread_desc
.
GetElementSpaceSize
();
constexpr
auto
b_thread_space_size
=
b_e_n_h
o
_w
o
_thread_desc
.
GetElementSpaceSize
();
Float
p_b_thread
[
b_thread_space_size
*
2
];
Float
*
p_b_thread_double
=
p_b_thread
;
...
...
@@ -246,12 +249,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
{
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
p_a_global
,
a_k_m_global_iterator_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_double
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e_k_block_desc
,
p_a_block_double
);
}
...
...
@@ -276,7 +279,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
__syncthreads
();
...
...
@@ -285,12 +288,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
p_a_global
,
a_k_m_global_iterator_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_odd
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_a_block_even
,
p_b_thread_even
,
p_c_thread
);
...
...
@@ -303,7 +306,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
__syncthreads
();
...
...
@@ -311,12 +314,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
p_a_global
,
a_k_m_global_iterator_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_even
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_a_block_odd
,
p_b_thread_odd
,
p_c_thread
);
...
...
@@ -335,7 +338,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
__syncthreads
();
...
...
@@ -343,12 +346,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
p_a_global
,
a_k_m_global_iterator_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_double
+
b_thread_space_size
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_thread_double
,
p_c_thread
);
...
...
@@ -375,8 +378,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_h_w_global tensor
constexpr
auto
c_k_n_h_w_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
// hack to control index calculation when iterating over c_k_n_h
o
_w
o
_global tensor
constexpr
auto
c_k_n_h
o
_w
o
_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
...
...
@@ -384,9 +387,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
AccFloat
,
Float
,
decltype
(
c_k_n_h_w_thread_desc
),
decltype
(
c_k_n_h_w_global_desc
),
Sequence
<
KPerThread
,
1
,
HPerThread
,
WPerThread
>
,
decltype
(
c_k_n_h
o
_w
o
_thread_desc
),
decltype
(
c_k_n_h
o
_w
o
_global_desc
),
Sequence
<
KPerThread
,
1
,
H
o
PerThread
,
W
o
PerThread
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
...
...
@@ -395,15 +398,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
make_multi_index
(
k_thread_data_on_global
,
0
,
h_thread_data_on_global
,
w_thread_data_on_global
))
.
Run
(
c_k_n_h_w_thread_desc
,
k_thread_data_on_global
,
0
,
h
o
_thread_data_on_global
,
w
o
_thread_data_on_global
))
.
Run
(
c_k_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_c_thread
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
c_k_n_h_w_global_tensor_iterator_hacks
);
c_k_n_h
o
_w
o
_global_tensor_iterator_hacks
);
}
#endif
}
...
...
@@ -412,9 +415,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_h_w_global_desc
,
const
BGlobalDesc
&
b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_h_w_global_desc
,
const
CGlobalDesc
&
c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
...
...
@@ -425,9 +428,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
...
...
@@ -438,22 +441,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_e_n_h_w_global_desc
,
const
BGlobalDesc
*
p_b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
*
p_c_k_n_h_w_global_desc
,
const
CGlobalDesc
*
p_c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_e_k_global_desc
=
*
p_a_e_k_global_desc
;
const
auto
b_e_n_h_w_global_desc
=
*
p_b_e_n_h_w_global_desc
;
const
auto
c_k_n_h_w_global_desc
=
*
p_c_k_n_h_w_global_desc
;
const
auto
a_e_k_global_desc
=
*
p_a_e_k_global_desc
;
const
auto
b_e_n_h
o
_w
o
_global_desc
=
*
p_b_e_n_h
o
_w
o
_global_desc
;
const
auto
c_k_n_h
o
_w
o
_global_desc
=
*
p_c_k_n_h
o
_w
o
_global_desc
;
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
@@ -463,24 +466,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
void
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
void
*
p_b_e_n_h_w_global_desc
,
const
void
*
p_b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
void
*
p_c_k_n_h_w_global_desc
,
const
void
*
p_c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_e_k_global_desc
=
*
reinterpret_cast
<
const
AGlobalDesc
*>
(
p_a_e_k_global_desc
);
const
auto
b_e_n_h_w_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_e_n_h_w_global_desc
);
const
auto
c_k_n_h_w_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h_w_global_desc
);
const
auto
b_e_n_h
o
_w
o
_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_e_n_h
o
_w
o
_global_desc
);
const
auto
c_k_n_h
o
_w
o
_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h
o
_w
o
_global_desc
);
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
@@ -495,12 +498,12 @@ template <index_t BlockSize,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
index_t
KPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
H
o
PerBlock
,
index_t
W
o
PerBlock
,
index_t
EPerBlock
,
index_t
KPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
H
o
PerThread
,
index_t
W
o
PerThread
,
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
...
...
@@ -548,9 +551,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_h_w_global_desc
,
const
BGlobalDesc
&
b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_h_w_global_desc
,
const
CGlobalDesc
&
c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
...
...
@@ -564,34 +567,34 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const
auto
E
=
a_e_k_global_desc
.
GetLength
(
I0
);
const
auto
K
=
a_e_k_global_desc
.
GetLength
(
I1
);
const
auto
N
=
b_e_n_h_w_global_desc
.
GetLength
(
I1
);
const
auto
H
=
b_e_n_h_w_global_desc
.
GetLength
(
I2
);
const
auto
W
=
b_e_n_h_w_global_desc
.
GetLength
(
I3
);
const
auto
N
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I1
);
const
auto
H
o
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I2
);
const
auto
W
o
=
b_e_n_h
o
_w
o
_global_desc
.
GetLength
(
I3
);
// divide block work by [M, N]
#if 1
const
auto
k_block_work_num
=
K
/
Number
<
KPerBlock
>
{};
const
auto
h_block_work_num
=
H
/
Number
<
HPerBlock
>
{};
const
auto
w_block_work_num
=
W
/
Number
<
WPerBlock
>
{};
const
auto
hw_block_work_num
=
h_block_work_num
*
w_block_work_num
;
const
auto
k_block_work_num
=
K
/
Number
<
KPerBlock
>
{};
const
auto
h
o
_block_work_num
=
H
o
/
Number
<
H
o
PerBlock
>
{};
const
auto
w
o
_block_work_num
=
W
o
/
Number
<
W
o
PerBlock
>
{};
const
auto
hw
o
_block_work_num
=
h
o
_block_work_num
*
w
o
_block_work_num
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
hw_block_work_num
;
const
index_t
hw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw_block_work_num
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
hw
o
_block_work_num
;
const
index_t
hw
o
_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw
o
_block_work_num
;
#else
// Hack: this force result into SGPR
const
index_t
k_block_work_num
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
const
index_t
h_block_work_num
=
__builtin_amdgcn_readfirstlane
(
H
/
HPerBlock
);
const
index_t
w_block_work_num
=
__builtin_amdgcn_readfirstlane
(
W
/
WPerBlock
);
const
index_t
hw_block_work_num
=
h_block_work_num
*
w_block_work_num
;
const
index_t
k_block_work_num
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
const
index_t
h
o
_block_work_num
=
__builtin_amdgcn_readfirstlane
(
H
o
/
H
o
PerBlock
);
const
index_t
w
o
_block_work_num
=
__builtin_amdgcn_readfirstlane
(
W
o
/
W
o
PerBlock
);
const
index_t
hw
o
_block_work_num
=
h
o
_block_work_num
*
w
o
_block_work_num
;
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
hw_block_work_num
);
const
index_t
hw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw_block_work_num
;
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
hw
o
_block_work_num
);
const
index_t
hw
o
_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw
o
_block_work_num
;
#endif
const
index_t
h_block_work_id
=
hw_block_work_id
/
w_block_work_num
;
const
index_t
w_block_work_id
=
hw_block_work_id
-
h_block_work_id
*
w_block_work_num
;
const
index_t
h
o
_block_work_id
=
hw
o
_block_work_id
/
w
o
_block_work_num
;
const
index_t
w
o
_block_work_id
=
hw
o
_block_work_id
-
h
o
_block_work_id
*
w
o
_block_work_num
;
// lds max alignment
constexpr
auto
max_lds_align
=
...
...
@@ -607,39 +610,43 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_e_n_h_w_block_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HPerBlock
>
{},
Number
<
WPerBlock
>
{}));
constexpr
auto
b_e_n_ho_wo_block_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{}));
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr
auto
c_k_n_h_w_thread_desc
=
constexpr
auto
c_k_n_h
o
_w
o
_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
decltype
(
a_e_k_block_desc
),
decltype
(
b_e_n_h_w_block_desc
),
decltype
(
c_k_n_h_w_thread_desc
),
KPerThread
,
// KPerThreadSubC
HPerThread
,
// HPerThreadSubC
WPerThread
,
// WPerThreadSubC
EPerThread
,
// EPerThreadLoop
1
,
// ThreadGemmADataPerRead_K
1
// ThreadGemmBDataPerRead_W
>
{};
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v3
<
BlockSize
,
decltype
(
a_e_k_block_desc
),
decltype
(
b_e_n_ho_wo_block_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
KPerThread
,
// KPerThreadSubC
HoPerThread
,
// HoPerThreadSubC
WoPerThread
,
// WoPerThreadSubC
EPerThread
,
// EPerThreadLoop
1
,
// ThreadGemmADataPerRead_K
1
// ThreadGemmBDataPerRead_W
>
{};
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
h_thread_id
=
c_thread_mtx_index
.
h
;
const
auto
w_thread_id
=
c_thread_mtx_index
.
w
;
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
h
o
_thread_id
=
c_thread_mtx_index
.
h
;
const
auto
w
o
_thread_id
=
c_thread_mtx_index
.
w
;
const
index_t
k_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
h_block_data_on_global
=
h_block_work_id
*
HPerBlock
;
const
index_t
w_block_data_on_global
=
w_block_work_id
*
WPerBlock
;
const
index_t
k_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
h
o
_block_data_on_global
=
h
o
_block_work_id
*
H
o
PerBlock
;
const
index_t
w
o
_block_data_on_global
=
w
o
_block_work_id
*
W
o
PerBlock
;
const
index_t
h_thread_data_on_global
=
h_block_data_on_global
+
h_thread_id
*
HPerThread
;
const
index_t
w_thread_data_on_global
=
w_block_data_on_global
+
w_thread_id
*
WPerThread
;
const
index_t
ho_thread_data_on_global
=
ho_block_data_on_global
+
ho_thread_id
*
HoPerThread
;
const
index_t
wo_thread_data_on_global
=
wo_block_data_on_global
+
wo_thread_id
*
WoPerThread
;
// A matrix blockwise copy
auto
a_blockwise_copy
=
...
...
@@ -670,16 +677,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
a_e_k_desc
,
make_multi_index
(
0
,
0
));
constexpr
auto
b_e_n_h_w_thread_desc
=
constexpr
auto
b_e_n_h
o
_w
o
_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
H
o
PerThread
>
{},
Number
<
W
o
PerThread
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseDynamicTensorSliceTransfer_v2
<
Float
,
Float
,
decltype
(
b_e_n_h_w_global_desc
),
decltype
(
b_e_n_h_w_thread_desc
),
Sequence
<
EPerBlock
,
1
,
HPerThread
,
WPerThread
>
,
decltype
(
b_e_n_h
o
_w
o
_global_desc
),
decltype
(
b_e_n_h
o
_w
o
_thread_desc
),
Sequence
<
EPerBlock
,
1
,
H
o
PerThread
,
W
o
PerThread
>
,
Sequence
<
3
,
2
,
0
,
1
>
,
// BBlockTransferSrcAccessOrder,
3
,
// BBlockTransferSrcVectorDim,
1
,
// BBlockTransferSrcScalarPerVector,
...
...
@@ -687,31 +694,31 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
,
1
,
true
>
(
b_e_n_h_w_global_desc
,
make_multi_index
(
0
,
0
,
h_thread_data_on_global
,
w_thread_data_on_global
));
true
>
(
b_e_n_h
o
_w
o
_global_desc
,
make_multi_index
(
0
,
0
,
h
o
_thread_data_on_global
,
w
o
_thread_data_on_global
));
Float
*
p_a_block
=
p_shared_block
;
// register allocation for output
AccFloat
p_c_thread
[
c_k_n_h_w_thread_desc
.
GetElementSpaceSize
()];
AccFloat
p_c_thread
[
c_k_n_h
o
_w
o
_thread_desc
.
GetElementSpaceSize
()];
// zero out threadwise output
threadwise_matrix_set_zero_v3
(
c_k_n_h_w_thread_desc
,
p_c_thread
);
threadwise_matrix_set_zero_v3
(
c_k_n_h
o
_w
o
_thread_desc
,
p_c_thread
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k_m_global_iterator_hacks
=
AGlobalIteratorHacks
{};
constexpr
auto
b_e_n_h_w_global_iterator_hacks
=
BGlobalIteratorHacks
{};
constexpr
auto
a_k_m_global_iterator_hacks
=
AGlobalIteratorHacks
{};
constexpr
auto
b_e_n_h
o
_w
o
_global_iterator_hacks
=
BGlobalIteratorHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
AGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_e_n_h_w_global_move_slice_window_iterator_hack
=
constexpr
auto
b_e_n_h
o
_w
o
_global_move_slice_window_iterator_hack
=
BGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_thread_space_size
=
b_e_n_h_w_thread_desc
.
GetElementSpaceSize
();
constexpr
auto
b_thread_space_size
=
b_e_n_h
o
_w
o
_thread_desc
.
GetElementSpaceSize
();
Float
p_b_thread
[
b_thread_space_size
*
2
];
Float
*
p_b_thread_double
=
p_b_thread
;
...
...
@@ -720,12 +727,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
{
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
p_a_global
,
a_k_m_global_iterator_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_double
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e_k_desc
,
p_a_block
);
}
...
...
@@ -745,15 +752,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_odd
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
...
...
@@ -763,15 +770,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
b_block_data_begin
+=
EPerBlock
;
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_even
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
...
...
@@ -787,15 +794,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_h
o
_w
o
_global_desc
,
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_h_w_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
b_e_n_h_w_thread_desc
,
b_e_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_b_thread_double
+
b_thread_space_size
,
b_e_n_h_w_global_iterator_hacks
);
b_e_n_h
o
_w
o
_global_iterator_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
...
...
@@ -824,8 +831,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_h_w_global tensor
constexpr
auto
c_k_n_h_w_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
// hack to control index calculation when iterating over c_k_n_h
o
_w
o
_global tensor
constexpr
auto
c_k_n_h
o
_w
o
_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
...
...
@@ -833,9 +840,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
AccFloat
,
Float
,
decltype
(
c_k_n_h_w_thread_desc
),
decltype
(
c_k_n_h_w_global_desc
),
Sequence
<
KPerThread
,
1
,
HPerThread
,
WPerThread
>
,
decltype
(
c_k_n_h
o
_w
o
_thread_desc
),
decltype
(
c_k_n_h
o
_w
o
_global_desc
),
Sequence
<
KPerThread
,
1
,
H
o
PerThread
,
W
o
PerThread
>
,
Sequence
<
3
,
2
,
0
,
1
>
,
// CThreadTransferSrcDstAccessOrder
3
,
// CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector
,
...
...
@@ -844,15 +851,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
make_multi_index
(
k_thread_data_on_global
,
0
,
h_thread_data_on_global
,
w_thread_data_on_global
))
.
Run
(
c_k_n_h_w_thread_desc
,
k_thread_data_on_global
,
0
,
h
o
_thread_data_on_global
,
w
o
_thread_data_on_global
))
.
Run
(
c_k_n_h
o
_w
o
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_c_thread
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
c_k_n_h_w_global_tensor_iterator_hacks
);
c_k_n_h
o
_w
o
_global_tensor_iterator_hacks
);
}
#endif
}
...
...
@@ -861,9 +868,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_e_n_h_w_global_desc
,
const
BGlobalDesc
&
b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_k_n_h_w_global_desc
,
const
CGlobalDesc
&
c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
...
...
@@ -874,9 +881,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
...
...
@@ -887,22 +894,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_e_n_h_w_global_desc
,
const
BGlobalDesc
*
p_b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
*
p_c_k_n_h_w_global_desc
,
const
CGlobalDesc
*
p_c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_e_k_global_desc
=
*
p_a_e_k_global_desc
;
const
auto
b_e_n_h_w_global_desc
=
*
p_b_e_n_h_w_global_desc
;
const
auto
c_k_n_h_w_global_desc
=
*
p_c_k_n_h_w_global_desc
;
const
auto
a_e_k_global_desc
=
*
p_a_e_k_global_desc
;
const
auto
b_e_n_h
o
_w
o
_global_desc
=
*
p_b_e_n_h
o
_w
o
_global_desc
;
const
auto
c_k_n_h
o
_w
o
_global_desc
=
*
p_c_k_n_h
o
_w
o
_global_desc
;
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
@@ -912,24 +919,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
void
*
p_a_e_k_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
void
*
p_b_e_n_h_w_global_desc
,
const
void
*
p_b_e_n_h
o
_w
o
_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
void
*
p_c_k_n_h_w_global_desc
,
const
void
*
p_c_k_n_h
o
_w
o
_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_e_k_global_desc
=
*
reinterpret_cast
<
const
AGlobalDesc
*>
(
p_a_e_k_global_desc
);
const
auto
b_e_n_h_w_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_e_n_h_w_global_desc
);
const
auto
c_k_n_h_w_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h_w_global_desc
);
const
auto
b_e_n_h
o
_w
o
_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_e_n_h
o
_w
o
_global_desc
);
const
auto
c_k_n_h
o
_w
o
_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h
o
_w
o
_global_desc
);
Run
(
a_e_k_global_desc
,
p_a_global
,
b_e_n_h_w_global_desc
,
b_e_n_h
o
_w
o
_global_desc
,
p_b_global
,
c_k_n_h_w_global_desc
,
c_k_n_h
o
_w
o
_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
5b242405
...
...
@@ -7,6 +7,10 @@
#endif
#include "bfloat16_dev.hpp"
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// GPU ID
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
...
...
@@ -15,22 +19,29 @@
#define CK_AMD_GPU_GFX1030 1
#endif
// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1
#endif
// buffer resource
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 1
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
5b242405
...
...
@@ -133,6 +133,39 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#elif 1
// cdata = 64, BlockSize 64, 16x256x2
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
2
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
1
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
5b242405
...
...
@@ -70,15 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HPerBlock
=
16
;
constexpr
index_t
WPerBlock
=
16
;
constexpr
index_t
CYX
PerBlock
=
4
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
H
o
PerBlock
=
16
;
constexpr
index_t
W
o
PerBlock
=
16
;
constexpr
index_t
E
PerBlock
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
HPerThread
=
2
;
constexpr
index_t
WPerThread
=
2
;
constexpr
index_t
CYX
PerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
H
o
PerThread
=
2
;
constexpr
index_t
W
o
PerThread
=
2
;
constexpr
index_t
E
PerThread
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
...
...
@@ -97,13 +97,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
TDevice
,
TDevice
,
KPerBlock
,
HPerBlock
,
WPerBlock
,
CYX
PerBlock
,
H
o
PerBlock
,
W
o
PerBlock
,
E
PerBlock
,
KPerThread
,
HPerThread
,
WPerThread
,
CYX
PerThread
,
H
o
PerThread
,
W
o
PerThread
,
E
PerThread
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
...
...
driver/src/conv_driver.cpp
View file @
5b242405
...
...
@@ -34,8 +34,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads
= Sequence<0, 0>;
using RightPads
= Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
...
...
@@ -736,7 +736,7 @@ int main(int argc, char* argv[])
LeftPads
{},
RightPads
{},
nrepeat
);
#elif
1
#elif
0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk
<
in_data_t
,
in_vector_size
,
acc_data_t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment