Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
071ca121
Commit
071ca121
authored
May 18, 2022
by
ltqin
Browse files
fix k0perthread and gridewis gemm main loop
parent
2159921e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
33 deletions
+45
-33
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp
..._operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp
+9
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
+34
-25
No files found.
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
View file @
071ca121
...
...
@@ -43,7 +43,7 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
#define NORMAL_CONFIG
0
#define NORMAL_CONFIG
1
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSkipLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
...
...
@@ -229,7 +229,7 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
#if
1
#if
0
{
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp
View file @
071ca121
...
...
@@ -41,6 +41,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
K0PerThread
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
...
...
@@ -278,7 +280,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
k
/
KPack
,
0
,
n0
,
0
,
0
,
i
))
>
{}];
make_tuple
(
0
,
0
,
k
/
KPack
,
0
,
n0
,
0
,
0
,
i
))
>
{}];
});
using
mfma_input_type
=
...
...
@@ -304,11 +306,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
// KPerThread
I1
,
// NBlockId
Number
<
NRepeat
>
{},
// repeat
I1
,
// waves
I1
,
// NPerXdlops
I1
,
Number
<
K0PerThread
>
{},
// KPerThread
I1
,
// NBlockId
Number
<
NRepeat
>
{},
// repeat
I1
,
// waves
I1
,
// NPerXdlops
Number
<
KPack
>
{}));
// C[M, N, NumRegXdlops]
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
View file @
071ca121
...
...
@@ -207,7 +207,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXDL
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
index_t
K0PerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
...
...
@@ -359,12 +359,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const
auto
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
=
transform_tensor_descriptor
(
b_grid_desc_k0_n_k1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
/
K0PerBlock
,
K0PerBlock
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
/
K0PerBlock
,
xdlops_gemm
.
K0PerXdlops
,
K0PerThread
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NXdlPerWave
*
NWaves
*
NPerXDL
),
NXdlPerWave
,
NWaves
,
NPerXDL
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
,
4
,
5
>
{},
Sequence
<
6
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{},
Sequence
<
7
>
{}));
return
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
;
}
...
...
@@ -383,7 +384,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
__device__
static
auto
GetWaveKNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_nk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
K0PerThread
,
NPerXDL
))),
make_tuple
(
make_merge_transform
(
make_tuple
(
xdlops_gemm
.
K0PerXdlops
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
...
...
@@ -559,7 +560,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// B matrix blockwise copy
constexpr
auto
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
K0PerThread
>
{},
// K0PerThread
I1
,
Number
<
K0PerThread
>
{},
// K0PerThread
I1
,
// NBlockId
Number
<
NXdlPerWave
>
{},
// repeat
I1
,
// waves
...
...
@@ -590,21 +592,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
wave_id[I2],
wave_k_n_id[I0],
wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k0b_n0_n1_n2_n3_k1.GetLength(I0));
#endif
auto
b_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
),
decltype
(
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1
),
Sequence
<
I1
,
Number
<
K0PerThread
>
{},
I1
,
Number
<
NXdlPerWave
>
{},
I1
,
I1
,
Number
<
K1
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
,
make_multi_index
(
0
,
wave_k_n_id
[
I0
],
block_work_idx
[
I1
],
0
,
wave_id
[
I1
],
wave_k_n_id
[
I1
],
0
));
auto
b_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
),
decltype
(
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1
),
Sequence
<
I1
,
I1
,
Number
<
K0PerThread
>
{},
I1
,
Number
<
NXdlPerWave
>
{},
I1
,
I1
,
Number
<
K1
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
,
make_multi_index
(
0
,
wave_k_n_id
[
I0
],
0
,
block_work_idx
[
I1
],
0
,
wave_id
[
I1
],
wave_k_n_id
[
I1
],
0
));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
...
...
@@ -634,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// gridwise GEMM pipeline
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
// preload data to regiester and LDS
{
// Read
...
...
@@ -642,7 +653,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
,
b_grid_buf
,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// Move
...
...
@@ -666,15 +677,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
,
c_thread_buf
);
// read b after gemm
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1
,
b_grid_buf
,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
,
c_thread_buf
);
block_sync_lds
();
// move windows
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment