gaoqiong / composable_kernel

Commit 0d02519a, authored Jun 09, 2022 by wangshaojie6

    add multik0

parent 4bda6db0
Showing 2 changed files with 65 additions and 219 deletions
example/01_gemm/gemm_xdl_fp16_splitk.cpp (+1, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_all_lds_v1.hpp (+64, -218)

example/01_gemm/gemm_xdl_fp16_splitk.cpp
@@ -49,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
+        < F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
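A quick way to sanity-check the new transfer parameters: each ABlockTransfer/BBlockTransfer ThreadClusterLengths tuple is expected to multiply out to the 256-thread block size used by this instance. A standalone compile-time check of that arithmetic (an illustrative sketch, not part of the diff):

    // sanity_check.cpp - standalone sketch; BlockSize and the S<...> values are copied from the diff above.
    #include <cstddef>

    constexpr std::size_t BlockSize = 256;

    constexpr std::size_t a_cluster[4] = {1, 8, 16, 2}; // new ABlockTransfer ThreadClusterLengths_K0_M_K1
    constexpr std::size_t b_cluster[4] = {1, 8, 32, 1}; // new BBlockTransfer ThreadClusterLengths_K0_N_K1

    constexpr std::size_t product(const std::size_t (&v)[4]) { return v[0] * v[1] * v[2] * v[3]; }

    static_assert(product(a_cluster) == BlockSize, "A copy cluster should cover the whole block");
    static_assert(product(b_cluster) == BlockSize, "B copy cluster should cover the whole block");

    int main() { return 0; }

The same check holds for the removed values (1*4*16*4 and 1*4*64*1 both equal 256), so the change keeps every thread busy while halving the per-thread K1 vector on the A side.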

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_all_lds_v1.hpp
@@ -49,7 +49,7 @@ __global__ void
    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   // p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                                   b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
@@ -115,7 +115,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
-   static constexpr auto MultiK0 = 4 * 1;
+   static constexpr auto MultiK0 = 8 * 1;
    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};
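MultiK0 is the number of register prefetch slots kept in flight: the hunks below size the per-thread A/B buffers as a_thread_buf[MultiK0] / b_thread_buf[MultiK0] and drive the preload and compute loops MultiK0 times. A minimal standalone sketch of that buffering pattern (ThreadTile is a hypothetical stand-in for CK's StaticBuffer):

    #include <array>
    #include <cstddef>

    // Hypothetical per-thread register tile; only the MultiK0 buffering pattern is shown.
    struct ThreadTile { float v[8]; };

    constexpr std::size_t MultiK0 = 8; // mirrors the new "8 * 1" value above

    int main()
    {
        // One slot per prefetch stage instead of four hand-named buffers (buf_0..buf_3).
        std::array<ThreadTile, MultiK0> a_thread_buf{};
        std::array<ThreadTile, MultiK0> b_thread_buf{};

        // Preload every slot once, as the i_pre do/while in the diff does; the main loop
        // then consumes slot i_k and immediately refills it from the next K0 slice.
        for (std::size_t i_pre = 0; i_pre < MultiK0; ++i_pre)
        {
            a_thread_buf[i_pre] = ThreadTile{}; // stands in for a_threadwise_copy.Run(...)
            b_thread_buf[i_pre] = ThreadTile{}; // stands in for b_threadwise_copy.Run(...)
        }
        return 0;
    }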
@@ -238,15 +238,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

        const auto a_griddesc_k0_mblockid_mrepeat_mwaves_mperxdlops_k1 = transform_tensor_descriptor(
            a_grid_desc_k0_m_k1,
            make_tuple(make_unmerge_transform(
                           make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
                       make_unmerge_transform(
                           make_tuple(M / (MXdlPerWave * MWaves * MPerXDL), MXdlPerWave, MWaves, MPerXDL)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));

        return a_griddesc_k0_mblockid_mrepeat_mwaves_mperxdlops_k1;
    }
@@ -256,15 +257,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        const auto b_griddesc_k0_nblockid_nrepeat_nwaves_nperxdlops_k1 = transform_tensor_descriptor(
            b_grid_desc_k0_n_k1,
            make_tuple(make_unmerge_transform(
                           make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
                       make_unmerge_transform(
                           make_tuple(N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));

        return b_griddesc_k0_nblockid_nrepeat_nwaves_nperxdlops_k1;
    }
@@ -399,7 +401,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        // void* __restrict__ p_shared,
        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
        const AGridDesc_K0_K1_K2_M0_M1_M2_M3_K3 a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
        const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
@@ -446,8 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
                     FloatAB,
                     a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
                     true>
-           a_thread_buf_0, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;
+           a_thread_buf[MultiK0]; //, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;

        ignore = b_element_op;
// B matrix threadwise copy
@@ -465,7 +466,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
                     FloatAB,
                     b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
                     true>
-           b_thread_buf_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
+           b_thread_buf[MultiK0]; //_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
        const auto wave_id     = GetWaveIdx();
        const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
@@ -513,7 +514,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
            make_multi_index(
                0, wave_k_m_id[I0], 0, block_work_idx[I0], 0, wave_id[I1], wave_k_m_id[I1], 0));

        auto b_threadwise_copy =
            ThreadwiseTensorSliceTransfer_v2<FloatAB,
                                             FloatAB,
@@ -561,250 +561,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // gridwise GEMM pipeline
        // constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
        constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
        constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

        // preload data to regiester and LDS
        {
-           // Read
-           b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_0);
-           a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_0);
-           // Move
-           a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-           b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-           // Read
-           b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_1);
-           a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_1);
-           // Move
-           a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-           b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-           // Read
-           b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_2);
-           a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_2);
-           // Move
-           a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-           b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-           b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_3);
-           a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_3);
+           index_t i_pre = 0;
+           do
+           {
+               b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf[i_pre]);
+               a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf[i_pre]);
+               // Move
+               a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
+               b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
+               i_pre++;
+           } while(i_pre < MultiK0);

        // Initialize C
        c_thread_buf.Clear();

        // a data write to lds
        // main body
        if constexpr(HasMainK0BlockLoop)
        {
-           index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
+           index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
            index_t i = 0;
            do
            {
-               blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
-               b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-               a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-               b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_0);
-               a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_0);
-               blockwise_gemm.Run(a_thread_buf_1, b_thread_buf_1, c_thread_buf);
-               b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-               a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-               b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_1);
-               a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_1);
-               blockwise_gemm.Run(a_thread_buf_2, b_thread_buf_2, c_thread_buf);
-               b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-               a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-               b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_2);
-               a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_2);
-               blockwise_gemm.Run(a_thread_buf_3, b_thread_buf_3, c_thread_buf);
-               b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-               a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-               b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_3);
-               a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_3);
-               i += 1;
-           } while(i < (K0BlockMainLoop - 1));
+               index_t i_k = 0;
+               do
+               {
+                   blockwise_gemm.Run(a_thread_buf[i_k], b_thread_buf[i_k], c_thread_buf);
+                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf[i_k]);
+                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf[i_k]);
+                   asm volatile("s_nop 0" ::);
+                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
+                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
+                   i_k++;
+               } while(i_k < MultiK0);
+               i += MultiK0;
+           } while(i < (K0BlockMainLoop - MultiK0));
        }

        // tail
        {
-           static_for<0, MultiK0, 4>{}([&](auto i) {
-               blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
-               if constexpr(i < MultiK0 - 4)
-               {
-                   // only move b windows
-                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_0);
-                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_0);
-               }
-               blockwise_gemm.Run(a_thread_buf_1, b_thread_buf_1, c_thread_buf);
-               if constexpr(i < MultiK0 - 4)
-               {
-                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_1);
-                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_1);
-               }
-               blockwise_gemm.Run(a_thread_buf_2, b_thread_buf_2, c_thread_buf);
-               if constexpr(i < MultiK0 - 4)
-               {
-                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_2);
-                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_2);
-               }
-               blockwise_gemm.Run(a_thread_buf_3, b_thread_buf_3, c_thread_buf);
-               if constexpr(i < MultiK0 - 4)
-               {
-                   // only move b windows
-                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
-                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
-                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf_3);
-                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf_3);
-               }
-           });
+           static_for<0, MultiK0, 1>{}([&](auto i) {
+               blockwise_gemm.Run(a_thread_buf[i], b_thread_buf[i], c_thread_buf);
+               if constexpr(i < MultiK0 - 4)
+               {
+                   b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
+                   a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_slice_copy_step);
+                   b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_buf[i]);
+                   a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf, a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf[i]);
+               }
+           });
        }
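Stripped of the descriptor and copy plumbing, the new control flow is a software pipeline over MultiK0 register slots: preload MultiK0 slices, then repeatedly compute on slot i_k while refilling it from the next slice, and finally drain the last MultiK0 slots in the tail. A host-side sketch of that schedule (load_slice/gemm_slot are hypothetical placeholders, and K0/K0PerBlock are example values):

    #include <cstdio>

    constexpr int MultiK0    = 8;   // prefetch depth, as in the diff
    constexpr int K0PerBlock = 8;   // K0 consumed per slot (example value)
    constexpr int K0         = 256; // total K0 (example value)

    // Placeholders for "read one K0PerBlock slice into a register slot" and
    // "run the xdlops block GEMM on one slot".
    void load_slice(int slot, int slice) { std::printf("load  slice %3d -> slot %d\n", slice, slot); }
    void gemm_slot(int slot)             { std::printf("gemm  slot %d\n", slot); }

    int main()
    {
        const int num_slices = K0 / K0PerBlock;
        int window = 0;

        // Preload MultiK0 slices (the i_pre do/while in the diff).
        for (int i_pre = 0; i_pre < MultiK0; ++i_pre)
            load_slice(i_pre, window++);

        // Main body: consume slot i_k, then refill it (the i / i_k loops in the diff).
        for (int i = 0; i < num_slices - MultiK0; i += MultiK0)
            for (int i_k = 0; i_k < MultiK0; ++i_k)
            {
                gemm_slot(i_k);
                load_slice(i_k, window++);
            }

        // Tail: drain the remaining preloaded slots (the static_for<0, MultiK0, 1> in the diff).
        for (int i_k = 0; i_k < MultiK0; ++i_k)
            gemm_slot(i_k);

        return 0;
    }

With these example numbers every slice is loaded exactly once and multiplied exactly once, which is the invariant the real kernel's preload, main body, and tail are arranged to preserve.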