Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8d4b51ca
Commit
8d4b51ca
authored
May 19, 2022
by
ltqin
Browse files
add test code
parent
4f88629d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
14 deletions
+129
-14
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
+122
-7
No files found.
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
View file @
8d4b51ca
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if NORMAL_CONFIG
#if NORMAL_CONFIG
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
6
4
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
25
6
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
#else
#else
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
true
,
7
,
1
>
;
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
true
,
7
,
1
>
;
#endif
#endif
...
@@ -84,13 +84,13 @@ int main(int argc, char* argv[])
...
@@ -84,13 +84,13 @@ int main(int argc, char* argv[])
// GEMM shape
// GEMM shape
#if NORMAL_CONFIG
#if NORMAL_CONFIG
ck
::
index_t
M
=
64
;
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
6
4
;
ck
::
index_t
K
=
4
096
;
ck
::
index_t
StrideA
=
6
4
;
ck
::
index_t
StrideA
=
4
096
;
ck
::
index_t
StrideB
=
6
4
;
ck
::
index_t
StrideB
=
4
096
;
ck
::
index_t
StrideC
=
128
;
ck
::
index_t
StrideC
=
4096
;
#else
#else
ck
::
index_t
M
=
16
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
N
=
16
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
View file @
8d4b51ca
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#define USING_SKIP_LDS 1
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
...
@@ -263,8 +263,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -263,8 +263,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
constexpr
auto
a_block_space_size_aligned
=
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
#if USING_SKIP_LDS
constexpr
auto
b_block_space_size_aligned
=
0
;
#else
constexpr
auto
b_block_desc_k0_n_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
#endif
return
a_block_space_size_aligned
*
sizeof
(
FloatAB
);
return
(
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
@@ -510,6 +517,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -510,6 +517,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const
index_t
n_block_data_idx_on_grid
=
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
...
@@ -546,8 +556,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -546,8 +556,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
a_block_desc_k0_m_k1
,
a_block_desc_k0_m_k1
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
#if USING_SKIP_LDS
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
max_lds_align
;
// B matrix blockwise copy
// B matrix blockwise copy
constexpr
auto
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
constexpr
auto
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
...
@@ -603,7 +614,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -603,7 +614,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
Number
<
K1
>
{}
>
,
Number
<
K1
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
7
,
1
,
ABlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
@@ -659,12 +670,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -659,12 +670,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// a data write to lds
// a data write to lds
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m_k1
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m_k1
,
a_block_buf
);
index_t
K0BlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
);
// main body
// main body
if
constexpr
(
HasMainK0BlockLoop
)
if
constexpr
(
HasMainK0BlockLoop
)
{
{
index_t
i
=
0
;
index_t
K0BlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
)
;
index_t
i
=
0
;
do
do
{
{
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
...
@@ -697,6 +707,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -697,6 +707,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
,
c_thread_buf
);
}
}
}
}
#else
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumPrefetch
>
(
b_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1
<
remove_cvref_t
<
decltype
(
a_grid_desc_k0_m_k1
)
>
,
remove_cvref_t
<
decltype
(
a_block_desc_k0_m_k1
)
>
,
remove_cvref_t
<
decltype
(
a_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
a_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
b_grid_desc_k0_n_k1
)
>
,
remove_cvref_t
<
decltype
(
b_block_desc_k0_n_k1
)
>
,
remove_cvref_t
<
decltype
(
b_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
b_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
blockwise_gemm
)
>
,
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
,
NumPrefetch
,
HasMainK0BlockLoop
>
{};
const
index_t
K0BlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
);
gridwise_gemm_pipeline
.
Run
(
a_grid_desc_k0_m_k1
,
a_block_desc_k0_m_k1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_k0_n_k1
,
b_block_desc_k0_n_k1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
K0BlockMainLoop
);
#endif
// output: register to global memory
// output: register to global memory
{
{
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment