gaoqiong / composable_kernel · Commits

Commit 70d06fa9
Authored Nov 30, 2020 by Chao Liu
fixing useless instruction issue
Parent: 7733dd88
Showing 4 changed files with 285 additions and 187 deletions
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (+111 -53)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp (+97 -104)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp (+63 -17)
composable_kernel/include/utility/in_memory_operation.amd.hpp.in (+14 -13)
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
@@ -173,8 +173,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

         // GEMM
+#if 1
         using gridwise_gemm =
-            GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
+            GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
                                               Float,
                                               AccFloat,
                                               InMemoryDataOperation::Set,
...
@@ -210,7 +211,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         const bool is_even_number_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;

-        const auto kernel_even =
+        if(is_even_number_k_block_loop)
+        {
+            const auto kernel =
                 run_gridwise_operation<gridwise_gemm,
                                        decltype(wei_gemmk_gemmm_global_desc),
                                        const Float*,
...
@@ -220,7 +223,22 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                                        Float*,
                                        integral_constant<bool, true>>;

-        const auto kernel_odd =
+            launch_kernel(kernel,
+                          dim3(GridSize),
+                          dim3(BlockSize),
+                          0,
+                          0,
+                          wei_gemmk_gemmm_global_desc,
+                          p_wei_global,
+                          in_gemmk_gemmn_global_desc,
+                          p_in_global,
+                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                          p_out_global,
+                          integral_constant<bool, true>{});
+        }
+        else
+        {
+            const auto kernel =
                 run_gridwise_operation<gridwise_gemm,
                                        decltype(wei_gemmk_gemmm_global_desc),
                                        const Float*,
...
@@ -230,9 +248,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                                        Float*,
                                        integral_constant<bool, false>>;

-        if(is_even_number_k_block_loop)
-        {
-            launch_kernel(kernel_even,
+            launch_kernel(kernel,
                           dim3(GridSize),
                           dim3(BlockSize),
                           0,
                           0,
...
@@ -243,11 +259,54 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                           p_in_global,
                           out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                           p_out_global,
-                          integral_constant<bool, true>{});
+                          integral_constant<bool, false>{});
         }
-        else
-        {
-            launch_kernel(kernel_odd,
+#else
+        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
+            BlockSize,
+            Float,
+            AccFloat,
+            InMemoryDataOperation::Set,
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
+            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
+            Sequence<1, 0>,
+            Sequence<1, 0>,
+            0,
+            GemmABlockTransferSrcScalarPerVector_GemmK,
+            GemmABlockTransferDstScalarPerVector_GemmM,
+            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
+            Sequence<0, 1>,
+            Sequence<0, 1>,
+            1,
+            GemmBBlockTransferSrcScalarPerVector_GemmN,
+            GemmBBlockTransferDstScalarPerVector_GemmN,
+            Sequence<2, 3, 0, 1>,
+            3,
+            GemmCThreadTransferDstScalarPerVector_GemmN1>;
+
+        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
+
+        const auto kernel = run_gridwise_operation<gridwise_gemm,
+                                                   decltype(wei_gemmk_gemmm_global_desc),
+                                                   const Float*,
+                                                   decltype(in_gemmk_gemmn_global_desc),
+                                                   const Float*,
+                                                   decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                                   Float*>;
+
+        launch_kernel(kernel,
                           dim3(GridSize),
                           dim3(BlockSize),
                           0,
                           0,
...
@@ -257,9 +316,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                           in_gemmk_gemmn_global_desc,
                           p_in_global,
                           out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                          p_out_global,
-                          integral_constant<bool, false>{});
-        }
+                          p_out_global);
+#endif
     }
 };
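The driver keeps the kernel's tail handling out of runtime control flow: it computes is_even_number_k_block_loop on the host and selects a kernel instantiation tagged with integral_constant<bool, ...>, so each instantiation compiles only one tail path. Below is a minimal, hedged sketch of that dispatch pattern in plain C++17; gemm_kernel_body and the sample sizes are illustrative stand-ins, not the repository's run_gridwise_operation/launch_kernel machinery.

#include <type_traits>
#include <cstdio>

// Hypothetical "kernel" body: the tail handling is chosen at compile time
// from the integral_constant tag, so the rejected branch emits no code.
template <bool IsEvenNumberKBlockLoop>
void gemm_kernel_body(int k_block_loops, std::integral_constant<bool, IsEvenNumberKBlockLoop>)
{
    if constexpr(IsEvenNumberKBlockLoop)
        std::printf("even tail: %d block-loops, unroll by 2\n", k_block_loops);
    else
        std::printf("odd tail: %d block-loops, handle one extra block\n", k_block_loops);
}

int main()
{
    const int GemmK = 1024, GemmKPerBlock = 8; // assumed example sizes
    const int k_block_loops = GemmK / GemmKPerBlock;

    // Host-side runtime check picks one of the two compile-time instantiations,
    // mirroring the is_even_number_k_block_loop dispatch in the driver.
    if((k_block_loops % 2) == 0)
        gemm_kernel_body(k_block_loops, std::integral_constant<bool, true>{});
    else
        gemm_kernel_body(k_block_loops, std::integral_constant<bool, false>{});
}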
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
@@ -42,7 +42,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct GridwiseDynamicGemm_km_kn_mn_v1r1
+struct GridwiseDynamicGemm_km_kn_mn_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
...
@@ -90,11 +90,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         const index_t N = b_k_n_global_desc.GetLength(I1);

         // divide block work by [M, N]
+#if 0
         const index_t m_block_work_num = M / MPerBlock;
         const index_t n_block_work_num = N / NPerBlock;
+#else
+        // Hack: this force result into SGPR
+        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
+        const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
+#endif

+#if 0
         const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
         const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
+#else
+        // Hack: this force result into SGPR
+        const index_t m_block_work_id =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
+        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
+#endif

         const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
         const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
...
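__builtin_amdgcn_readfirstlane broadcasts the value held by the first active lane and its result is treated as wavefront-uniform, so block-uniform index math like the lines above can live in scalar registers (SGPRs) and use scalar instructions instead of being recomputed per lane in VGPRs; this is the kind of redundant per-lane work the commit title points at. The following is a device-side HIP sketch of the idiom only; the kernel, buffer layout, and the 128-wide tile sizes are illustrative assumptions, and only the intrinsic usage mirrors the diff.

#include <hip/hip_runtime.h>

// Illustrative kernel: block-uniform tile indices forced into scalar registers.
__global__ void tile_index_kernel(const int* __restrict__ sizes, int* __restrict__ out)
{
    const int N = sizes[0];
    const int MPerBlock = 128, NPerBlock = 128; // assumed tile sizes

    // Same hack as in the diff: readfirstlane marks the result as
    // wavefront-uniform, so it can be kept in an SGPR.
    const int n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
    const int m_block_work_id  = __builtin_amdgcn_readfirstlane(blockIdx.x / n_block_work_num);
    const int n_block_work_id  = blockIdx.x - m_block_work_id * n_block_work_num;

    if(threadIdx.x == 0)
    {
        out[2 * blockIdx.x]     = m_block_work_id * MPerBlock;
        out[2 * blockIdx.x + 1] = n_block_work_id * NPerBlock;
    }
}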
@@ -117,7 +130,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         // A matrix blockwise copy
         auto a_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
+            BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
                                                      Float,
                                                      Float,
                                                      decltype(a_k_m_global_desc),
...
@@ -136,14 +149,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                                                      AddressSpace::Lds,
                                                      InMemoryDataOperation::Set,
                                                      1,
-                                                     1>(a_k_m_global_desc,
+                                                     1,
+                                                     true,
+                                                     true>(a_k_m_global_desc,
                                                         make_multi_index(0, m_block_data_on_global),
                                                         a_k_m_block_desc,
                                                         make_multi_index(0, 0));

         // B matrix blockwise copy
         auto b_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
+            BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
                                                      Float,
                                                      Float,
                                                      decltype(b_k_n_global_desc),
...
@@ -162,7 +178,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                                                      AddressSpace::Lds,
                                                      InMemoryDataOperation::Set,
                                                      1,
-                                                     1>(b_k_n_global_desc,
+                                                     1,
+#if 0
+                                                     true,
+#else
+                                                     false,
+#endif
+                                                     true>(b_k_n_global_desc,
                                                         make_multi_index(0, n_block_data_on_global),
                                                         b_k_n_block_desc,
                                                         make_multi_index(0, 0));
...
@@ -230,12 +253,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+#if 0
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+#else
+        // HACK: fuse threadwise copy move-back coordinate with move src slice window
+        constexpr auto b_block_slice_copy_step =
+            b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
+#endif

         // LDS double buffer: preload data into LDS
         {
-            a_block_copy.Run(a_k_m_global_desc, p_a_global, a_k_m_block_desc, p_a_block_double);
-            b_block_copy.Run(b_k_n_global_desc, p_b_global, b_k_n_block_desc, p_b_block_double);
+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
         }

         // LDS double buffer: main body
...
@@ -262,16 +298,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                 __syncthreads();

+                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
                 // LDS doubel buffer: load next data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
+                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);

                 // LDS double buffer: GEMM on current data
                 block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                 // LDS double buffer: store next data to LDS
-                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
-                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
+                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
+                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
             }
         }
...
@@ -284,16 +323,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
             __syncthreads();

+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
             // LDS double buffer: load last data from device mem
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);

             // LDS double buffer: GEMM on 2nd-last data
             block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

             // LDS double buffer: store last data to LDS
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_block_copy.RunWrite(
+                a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
+            b_block_copy.RunWrite(
+                b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);

             __syncthreads();
...
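The switch from BlockwiseDynamicTensorSliceTransfer_v2r2 to v2r3 above splits the copy into two explicit stages, RunRead (global memory into a per-thread register buffer) and RunWrite (register buffer into LDS), so the caller can interleave the global load, the block GEMM and the LDS store however the schedule needs. The following is a toy, host-side C++ sketch of that two-stage shape; TileCopy and its fixed-size buffer are invented for illustration and do not reflect the real transfer class.

#include <array>
#include <cstdio>

// Toy stand-in for the blockwise transfer: a two-stage copy through a
// "register" staging buffer, analogous to RunRead + RunWrite.
template <int TileSize>
struct TileCopy
{
    std::array<float, TileSize> thread_buffer{}; // per-thread staging registers

    void RunRead(const float* src, int src_offset) // global -> registers
    {
        for(int i = 0; i < TileSize; ++i)
            thread_buffer[i] = src[src_offset + i];
    }

    void RunWrite(float* dst, int dst_offset) const // registers -> LDS
    {
        for(int i = 0; i < TileSize; ++i)
            dst[dst_offset + i] = thread_buffer[i];
    }
};

int main()
{
    float global[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float lds[8]    = {};

    TileCopy<4> copy;
    copy.RunRead(global, 4); // stage the next tile; other work could run in between
    copy.RunWrite(lds, 0);   // commit the staged tile to the shared buffer

    std::printf("lds[0..3] = %g %g %g %g\n", lds[0], lds[1], lds[2], lds[3]);
}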
@@ -411,7 +455,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct GridwiseDynamicGemm_km_kn_mn_v1r2
+struct GridwiseDynamicGemm_km_kn_mn_v2
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
...
@@ -437,18 +481,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         constexpr index_t b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

-        return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
+        return (a_block_space_size + b_block_space_size) * sizeof(Float);
     }

-    template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
+    template <typename... ADesc, typename... BDesc, typename... CDesc>
     __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
                         const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
                         const Float* __restrict__ p_b_global,
                         const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
                         Float* __restrict__ p_c_global,
-                        Float* __restrict__ p_shared_block,
-                        integral_constant<bool, IsEvenNumberKBlockLoop>) const
+                        Float* __restrict__ p_shared_block) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
...
@@ -612,8 +655,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         constexpr index_t b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

-        Float* p_a_block_double = p_shared_block;
-        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+        Float* p_a_block = p_shared_block;
+        Float* p_b_block = p_shared_block + a_block_space_size;

         // register allocation for output
         AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
...
@@ -631,7 +674,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
 #endif

-        // LDS double buffer: preload data into LDS
+        // preload data into LDS
         {
             Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
             Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
...
@@ -639,89 +682,41 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
             b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);

-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
         }

-        // LDS double buffer: main body
-        for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
-            k_block_data_begin += 2 * KPerBlock)
+        // main body
+        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
+            k_block_data_begin += KPerBlock)
         {
-#pragma unroll
-            for(index_t iloop = 0; iloop < 2; ++iloop)
-            {
-                const bool even_loop = (iloop % 2 == 0);
-
-                Float* p_a_block_now =
-                    even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
-                Float* p_b_block_now =
-                    even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
-
-                Float* p_a_block_next =
-                    even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
-                Float* p_b_block_next =
-                    even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
-
-                a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
-
-                __syncthreads();
-
-                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-                // LDS doubel buffer: load next data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-                // LDS double buffer: GEMM on current data
-                block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
-
-                // LDS double buffer: store next data to LDS
-                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
-                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
-            }
-        }
-
-        // LDS double buffer: tail
-        {
-            if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
-            {
-                a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
-
-                __syncthreads();
-
-                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-                // LDS double buffer: load last data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-                // LDS double buffer: GEMM on 2nd-last data
-                block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
-
-                // LDS double buffer: store last data to LDS
-                a_block_copy.RunWrite(
-                    a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
-                b_block_copy.RunWrite(
-                    b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
-
-                __syncthreads();
-
-                // LDS double buffer: GEMM on last data
-                block_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
-            }
-            else // if has 1 iteration left
-            {
-                __syncthreads();
-
-                // LDS double buffer: GEMM on last data
-                block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
-            }
+            a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
+            b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
+
+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
+            // load next data from device mem
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+
+            __syncthreads();
+
+            // GEMM on current data
+            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
+
+            __syncthreads();
+
+            // store next data to LDS
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
+        }
+
+        // tail
+        {
+            __syncthreads();
+
+            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
         }

         // output: register to global memory
...
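With the even/odd tail machinery gone, the renamed v2 kernel keeps a single LDS tile per matrix and pipelines each iteration as: stage the next K-block in registers, synchronize, GEMM on the tile already in LDS, synchronize again, then overwrite the LDS tile. The host-side sketch below only prints that schedule; read_tile, compute and write_tile are placeholders for RunRead, block_gemm.Run and RunWrite, and the barrier points are shown as comments rather than real synchronization.

#include <cstdio>

// Schematic of the single-LDS-buffer schedule used by the rewritten main body.
void read_tile(int k)  { std::printf("  RunRead  : stage K-block %d in registers\n", k); }
void compute(int k)    { std::printf("  GEMM     : consume K-block %d from LDS\n", k); }
void write_tile(int k) { std::printf("  RunWrite : store K-block %d to LDS\n", k); }

int main()
{
    const int K = 32, KPerBlock = 8; // assumed example sizes

    read_tile(0);  // preload: first tile goes registers -> LDS before the loop
    write_tile(0);

    for(int k = 0; k + KPerBlock < K; k += KPerBlock)
    {
        read_tile(k + KPerBlock);  // stage the next tile (global -> registers)
        // __syncthreads();        // make the current LDS tile visible to all threads
        compute(k);                // GEMM on the tile already in LDS
        // __syncthreads();        // everyone done reading LDS before it is overwritten
        write_tile(k + KPerBlock); // registers -> the single LDS tile
    }

    compute(K - KPerBlock);        // tail: the last tile is already in LDS
}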
@@ -769,14 +764,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         }
     }

-    template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
+    template <typename... ADesc, typename... BDesc, typename... CDesc>
     __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
                         const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
                         const Float* __restrict__ p_b_global,
                         const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
-                        Float* __restrict__ p_c_global,
-                        integral_constant<bool, IsEvenNumberKBlockLoop>) const
+                        Float* __restrict__ p_c_global) const
     {
         constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
...
@@ -788,8 +782,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             p_b_global,
             c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            p_shared_block,
-            integral_constant<bool, IsEvenNumberKBlockLoop>{});
+            p_shared_block);
     }
 };
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -255,15 +255,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
             constexpr index_t Len0 = SliceLengths{}[0];
             constexpr index_t Len1 = SliceLengths{}[1];

-            bool forward_dim0 = true;
-            bool forward_dim1 = true;
-
 #pragma unroll
             for(index_t i0 = 0; i0 < Len0; ++i0)
             {
 #pragma unroll
                 for(index_t i1 = 0; i1 < Len1; ++i1)
                 {
+#if 1 // debug
                     // do work
                     transfer_data<SrcData,
                                   1,
...
@@ -282,10 +280,69 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                                   coordinate_has_valid_offset_assuming_visible_index_is_valid(
                                       dst_desc, dst_slice_origin_),
                                   dst_desc.GetElementSpaceSize());
+#else
+                    if constexpr(SrcAddressSpace == AddressSpace::Global &&
+                                 DstAddressSpace == AddressSpace::Vgpr)
+                    {
+                        if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                               dst_desc, dst_slice_origin_))
+                        {
+                            const SrcData tmp = amd_buffer_load<SrcData, 1>(
+                                p_src,
+                                src_slice_origin_.GetOffset(),
+                                coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                                    src_desc, src_slice_origin_),
+                                src_desc.GetElementSpaceSize());
+
+                            const index_t dst_offset = dst_slice_origin_.GetOffset();
+
+                            p_dst[dst_offset] = tmp;
+                        }
+                    }
+                    else if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
+                                      DstAddressSpace == AddressSpace::Global)
+                    {
+                        const SrcData zeros = 0;
+
+                        const bool src_valid =
+                            coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                                src_desc, src_slice_origin_);
+                        const bool dst_valid =
+                            coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                                dst_desc, dst_slice_origin_);
+
+                        amd_buffer_store<SrcData, 1>(
+                            src_valid ? &(p_src[src_slice_origin_.GetOffset()]) : &zeros,
+                            p_dst,
+                            dst_slice_origin_.GetOffset(),
+                            dst_valid,
+                            dst_desc.GetElementSpaceSize());
+                    }
+                    else
+                    {
+                        if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                               dst_desc, dst_slice_origin_))
+                        {
+                            if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                                   src_desc, src_slice_origin_))
+                            {
+                                p_dst[dst_slice_origin_.GetOffset()] =
+                                    p_src[src_slice_origin_.GetOffset()];
+                            }
+                            else
+                            {
+                                p_dst[dst_slice_origin_.GetOffset()] = 0;
+                            }
+                        }
+                    }
+#endif

                     // move dim1 iterator
                     if(i1 < Len1 - 1)
                     {
+                        bool forward_dim1 = (i0 % 2 == 0);
+
                         if(forward_dim1)
                         {
                             move_dynamic_tensor_coordinate(
...
@@ -303,23 +360,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                         }
                     }
                 }

-                // switch dim1 iteration direction
-                forward_dim1 = !forward_dim1;
-
                 // move dim0 iterator
                 if(i0 < Len0 - 1)
                 {
-                    if(forward_dim0)
-                    {
                     move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
                     move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
-                    }
-                    else
-                    {
-                        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_m1_0);
-                        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_m1_0);
-                    }
                 }
             }
         }
         else if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 4)
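The dim1 direction is now derived from the row index rather than tracked in mutable flags: even i0 rows sweep dim1 forward, odd rows sweep it backward, so after finishing a row the coordinate only needs a single +1 step in dim0 instead of a move back to the start of the row. A self-contained, index-only illustration of that snake-order visit pattern follows (plain C++; the Len0/Len1 values are invented).

#include <cstdio>

int main()
{
    const int Len0 = 3, Len1 = 4;

    // Snake-order traversal: the dim1 direction depends only on i0,
    // mirroring "bool forward_dim1 = (i0 % 2 == 0);" in the transfer.
    for(int i0 = 0; i0 < Len0; ++i0)
    {
        const bool forward_dim1 = (i0 % 2 == 0);

        for(int step = 0; step < Len1; ++step)
        {
            const int i1 = forward_dim1 ? step : (Len1 - 1 - step);
            std::printf("visit (%d, %d)\n", i0, i1);
        }
        // Moving to the next i0 needs only a +1 step in dim0; dim1 is already
        // sitting at the correct end of the row, so no move-back is required.
    }
}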
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
@@ -185,25 +185,26 @@ __device__ void transfer_data(const T* p_src,
                   "wrong! InMemoryDataOperation not supported!");

     // keep it simple, don't use static_if here, otherwise compiler will do weird things
-    if(SrcDataStride == 1 && DstDataStride == 1)
+    if constexpr(SrcDataStride == 1 && DstDataStride == 1)
     {
-        // TODO: use static_if::ElseIf
-        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+        if constexpr(DstInMemOp == InMemoryDataOperation::Set)
+        {
             SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                 p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
-        });
-
-        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+        }
+        else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
+        {
             AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                 p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
-        });
+        }
     }
     else
     {
+#pragma unroll
         for(index_t i = 0; i < DataPerAccess; ++i)
         {
-            // TODO: use static_if::ElseIf
-            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+            if constexpr(DstInMemOp == InMemoryDataOperation::Set)
+            {
                 SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                     p_src,
                     src_offset + i * SrcDataStride,
...
@@ -213,9 +214,9 @@ __device__ void transfer_data(const T* p_src,
                     dst_offset + i * DstDataStride,
                     dst_valid,
                     dst_range);
-            });
-
-            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+            }
+            else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
+            {
                 AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                     p_src,
                     src_offset + i * SrcDataStride,
...
@@ -225,7 +226,7 @@ __device__ void transfer_data(const T* p_src,
                     dst_offset + i * DstDataStride,
                     dst_valid,
                     dst_range);
-            });
+            }
         }
     }
 }
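This file drops the static_if<Cond>{}([&](auto){ ... }) lambdas in favor of C++17 if constexpr, so the operation that does not apply is discarded at instantiation time instead of being wrapped in a generic lambda the backend must later eliminate. A minimal stand-alone comparison of the shape of the new dispatch follows; set_data and atomic_add_data are simplified stand-ins, not the project's SetData/AtomicAddData functors.

#include <cstdio>

enum class MemOp { Set, AtomicAdd };

// Simplified stand-ins for the two in-memory operations.
void set_data(float* dst, float v)        { *dst  = v; }
void atomic_add_data(float* dst, float v) { *dst += v; } // placeholder, not a real atomic

template <MemOp Op>
void transfer_one(float* dst, float v)
{
    // C++17: the discarded branch is never instantiated, so no dead
    // code is generated for the operation that was not selected.
    if constexpr(Op == MemOp::Set)
        set_data(dst, v);
    else if constexpr(Op == MemOp::AtomicAdd)
        atomic_add_data(dst, v);
}

int main()
{
    float x = 1.0f;
    transfer_one<MemOp::Set>(&x, 3.0f);       // x = 3
    transfer_one<MemOp::AtomicAdd>(&x, 2.0f); // x = 5
    std::printf("x = %g\n", x);
}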