gaoqiong / composable_kernel_ROCM / Commits / 80e89ebd

Commit 80e89ebd
authored Jan 24, 2025 by aska-0096
"minimum reproducible example for warp-specialized scheduling"
parent af30d6b6
Showing 2 changed files with 499 additions and 229 deletions:

example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp (+3, -3)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp (+496, -226)
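What the commit title calls "warp-specialized scheduling" shows up in the second file below: all waves of the workgroup run the same GEMM hot loop, but the first two waves issue it under HotLoopScheduler while the remaining waves use HotLoopScheduler_B, selected from the wave id. The following host-side sketch only illustrates that dispatch; in the kernel the decision uses __builtin_amdgcn_readfirstlane(get_warp_local_1d_id()), and the eight-wave workgroup size here is an assumption.

// Host-side illustration (not the kernel code) of how waves are routed to
// one of the two hot-loop schedules added by this commit.
#include <cstdio>

enum class Schedule { A, B };

Schedule pick_schedule(int warp_id)
{
    // Mirrors the "if(warp_id < 2) ... else ..." branch in the diff below.
    return warp_id < 2 ? Schedule::A : Schedule::B;
}

int main()
{
    for(int warp_id = 0; warp_id < 8; ++warp_id) // 8 waves per workgroup is an assumption
        std::printf("wave %d -> HotLoopScheduler%s\n",
                    warp_id,
                    pick_schedule(warp_id) == Schedule::A ? "" : "_B");
}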
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp (view file @ 80e89ebd)
...
...
@@ -143,11 +143,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
        AElementOp, BElementOp, CDEElementOp, GemmSpec,
        256, 256, 256, 128, 16, 16, 32, 32, 4, 4,
        16, 16, 8, 8,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
        1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
        1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp (view file @ 80e89ebd)
...
...
@@ -11,7 +11,7 @@ namespace ck {
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
// LocalSharedMemoryBuffer: 2
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
...
...
@@ -145,10 +145,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
    using Base::MWaves;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t a_local_write_issue_stage     = NPerXDL == 32 ? 1 : 2;
    static constexpr index_t a_global_read_issue_stage     = NPerXDL == 32 ? 2 : 4;
    static constexpr index_t a_global_read_issue_stage_end = NPerXDL == 32 ? 3 : 6;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
...
...
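The three new constants split the MRepeat iterations of the hot loop into issue windows: indices below a_local_write_issue_stage schedule the B global loads, indices below a_global_read_issue_stage schedule the A LDS writes, indices below a_global_read_issue_stage_end schedule the A global loads, and the rest only interleave LDS reads with MFMAs. A small host-side sketch of that partition; the NPerXDL-dependent thresholds are copied from the diff, while MRepeat = 8 is an assumption for illustration.

// Illustration of the per-m0 stage windows defined above.
#include <cstdio>

constexpr int NPerXDL = 32;                // assumption for this example
constexpr int a_local_write_issue_stage     = NPerXDL == 32 ? 1 : 2;
constexpr int a_global_read_issue_stage     = NPerXDL == 32 ? 2 : 4;
constexpr int a_global_read_issue_stage_end = NPerXDL == 32 ? 3 : 6;

const char* window(int m0)
{
    if(m0 < a_local_write_issue_stage) return "B global load";
    if(m0 < a_global_read_issue_stage) return "A LDS write";
    if(m0 < a_global_read_issue_stage_end) return "A global load";
    return "A LDS read only";
}

int main()
{
    constexpr int MRepeat = 8; // assumption
    for(int m0 = 0; m0 < MRepeat; ++m0)
        std::printf("m0 = %d -> %s window\n", m0, window(m0));
}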
@@ -187,7 +190,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
    template <typename Stage>
    __device__ static constexpr auto HotLoopScheduler(Stage stage)
    {
        constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_ds_read_grouped = KPack / A_K1;
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
...
...
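The new num_ds_read_grouped = KPack / A_K1 factor budgets the A LDS reads in groups: each group expands to KPack / A_K1 back-to-back ds_read instructions, so the per-stage counts are divided by that factor. A rough host-side check of the arithmetic; all three input values below are assumptions, not the kernel's actual counts.

// Rough arithmetic behind the grouping above; the constants are assumptions.
#include <cstdio>

int main()
{
    constexpr int KPack = 32;               // assumption
    constexpr int A_K1  = 16;               // assumption
    constexpr int A_LDS_Read_Inst_Num = 32; // assumption, per hot-loop iteration

    constexpr int num_ds_read_grouped = KPack / A_K1;                              // 2
    constexpr int num_ds_read_inst_a  = A_LDS_Read_Inst_Num / num_ds_read_grouped; // 16

    // The scheduler now places 16 groups, each expanding to 2 consecutive ds_reads.
    std::printf("groups = %d, ds_read per group = %d\n",
                num_ds_read_inst_a, num_ds_read_grouped);
}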
@@ -199,12 +204,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
        constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

        if constexpr(stage.value == 0)
        if constexpr(stage.value < a_local_write_issue_stage)
        {
            constexpr auto issue_stages = a_local_write_issue_stage;
            constexpr auto staged_num_buffer_load_b_per_ds_read_a =
                num_buffer_load_inst_b / staged_num_ds_read_inst_a;
                num_buffer_load_inst_b / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_buffer_load_b =
                staged_num_mfma / num_buffer_load_inst_b;
                issue_stages * staged_num_mfma / num_buffer_load_inst_b;
            // B global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
...
...
@@ -216,129 +223,105 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                ignore = ids_inst;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            });
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value == 1)
        else if constexpr(stage.value < a_global_read_issue_stage)
        {
            constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
            constexpr auto staged_num_ds_write_a_per_ds_read_a =
                num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_ds_write_a =
                math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
                issue_stages * staged_num_mfma / num_ds_write_inst_a;
            // A local write
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
            constexpr auto stage_more_mfma =
                staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
                static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
                    ignore = idswrite_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                });
            // A local write
            static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
                if constexpr(i_inst.value < stage_more_mfma)
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
                else
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_ds_write_a_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value == 2)
        else if constexpr(stage.value < a_global_read_issue_stage_end)
        {
            constexpr auto issue_stages = a_global_read_issue_stage_end - a_global_read_issue_stage;
            constexpr auto staged_num_buffer_load_a_per_ds_read_a =
                num_buffer_load_inst_a / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_buffer_load_a =
                math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
                issue_stages * staged_num_mfma / num_buffer_load_inst_a;
            // A global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, staged_num_buffer_load_a_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
                    ignore = ibuf_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                });
            constexpr auto stage_more_mfma =
                staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
            // A global
            static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
                if constexpr(i_inst.value < stage_more_mfma)
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                    }
                }
                else
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                    }
                }
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_a - num_ds_read_grouped, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
        else
        {
            // A local Read
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
    }
    template <typename Stage>
    __device__ static constexpr auto EpilogueScheduler_1(Stage stage)
    {
        constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_ds_read_grouped = KPack / A_K1;
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
...
...
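The mask values passed to __builtin_amdgcn_sched_group_barrier in these hunks follow the comments in the diff: 0x008 selects MFMA instructions, 0x020 VMEM reads, 0x100 DS (LDS) reads, and 0x200 DS writes; each call pins `count` instructions of that class at that point in the issue order, and __builtin_amdgcn_sched_barrier(0) fences reordering across the stage. A minimal device-side sketch of the idiom, compilable only for AMD GCN/CDNA targets; the counts (4 MFMAs per load, 8 loads) are illustrative and not the kernel's actual numbers.

// Device-side sketch of the interleaving idiom used throughout this file:
// pin one VMEM load between short runs of MFMAs.
__device__ inline void interleave_vmem_with_mfma()
{
    constexpr int loads_per_stage    = 8; // assumption
    constexpr int mfma_per_vmem_read = 4; // assumption

    for(int i = 0; i < loads_per_stage; ++i)
    {
        __builtin_amdgcn_sched_group_barrier(0x008, mfma_per_vmem_read, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);                  // VMEM read
    }
    __builtin_amdgcn_sched_barrier(0); // keep the compiler from reordering across stages
}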
@@ -349,38 +332,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
        constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

        if constexpr(stage.value == 0)
        if constexpr(stage.value < a_local_write_issue_stage)
        {
            constexpr auto issue_stages = a_local_write_issue_stage;
            constexpr auto staged_num_buffer_load_b_per_ds_read_a =
                num_buffer_load_inst_b / staged_num_ds_read_inst_a;
                num_buffer_load_inst_b / (a_local_write_issue_stage * issue_stages);
            constexpr auto staged_num_mfma_per_buffer_load_b =
                staged_num_mfma / num_buffer_load_inst_b;
                issue_stages * staged_num_mfma / num_buffer_load_inst_b;
            // B global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
                static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
                    ignore = ibuf_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                });
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value == 1)
        else if constexpr(stage.value < a_global_read_issue_stage)
        {
#if 0
            constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
            constexpr auto staged_num_ds_write_a_per_ds_read_a =
                num_ds_write_inst_a / staged_num_ds_read_inst_a;
            constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
                num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_ds_write_a =
                issue_stages * staged_num_mfma / num_ds_write_inst_a;
            // A local write
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
...
...
@@ -392,74 +383,41 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
            });
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_ds_write_a_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
            });
#elif 1
            constexpr auto staged_num_mfma_per_ds_write_a =
                math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
            constexpr auto stage_more_mfma =
                staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
            // A local write
            static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
                if constexpr(i_inst.value < stage_more_mfma)
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
                else
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
            });
#endif
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
        else
        {
            // A local Read
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                });
            });
            __builtin_amdgcn_sched_barrier(0);
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    __device__ static constexpr auto EpilogueScheduler_2()
    {
        constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_ds_read_grouped = KPack / A_K1;
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
        constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
...
...
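The integer_divide_ceil / stage_more_mfma pair in the hunk above distributes a stage's MFMAs over its DS writes when the division is not exact: the first stage_more_mfma writes get the larger MFMA run and the rest get one fewer, so the total still matches. A tiny host-side check of that bookkeeping; the two counts are assumptions, and the real hunk additionally interleaves DS reads into the first iterations.

// Host-side check of the ceil-division remainder bookkeeping.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int staged_num_mfma     = 14; // assumption
    constexpr int num_ds_write_inst_a = 4;  // assumption

    constexpr int per_write = ceil_div(staged_num_mfma, num_ds_write_inst_a);              // 4
    constexpr int more_mfma = staged_num_mfma - (per_write - 1) * num_ds_write_inst_a;     // 2

    int issued = 0;
    for(int i = 0; i < num_ds_write_inst_a; ++i)
        issued += (i < more_mfma) ? per_write : per_write - 1; // 4 + 4 + 3 + 3 = 14

    std::printf("MFMAs covered: %d of %d\n", issued, staged_num_mfma);
}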
@@ -471,13 +429,147 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
        // A local Read
        static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
            ignore = i_inst;
            __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                ignore = ids_inst;
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            });
        });
        __builtin_amdgcn_sched_barrier(0);
    }
    template <typename Stage>
    __device__ static constexpr auto HotLoopScheduler_B(Stage stage)
    {
        constexpr auto num_ds_read_grouped = KPack / A_K1;
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
        constexpr auto staged_num_mfma           = num_mfma / MRepeat;
        constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

        if constexpr(stage.value < a_local_write_issue_stage)
        {
            constexpr auto issue_stages = a_local_write_issue_stage;
            constexpr auto staged_num_buffer_load_b_per_ds_read_a =
                num_buffer_load_inst_b / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_buffer_load_b =
                issue_stages * staged_num_mfma / num_buffer_load_inst_b;
            // B global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
                    ignore = ibuf_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                });
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            // __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value < a_global_read_issue_stage)
        {
            constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
            constexpr auto staged_num_ds_write_a_per_ds_read_a =
                num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_ds_write_a =
                issue_stages * staged_num_mfma / num_ds_write_inst_a;
            // A local write
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
                    ignore = idswrite_inst;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                });
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_ds_write_a_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                });
            });
            // __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value < a_global_read_issue_stage_end)
        {
            constexpr auto issue_stages = a_global_read_issue_stage_end - a_global_read_issue_stage;
            constexpr auto staged_num_buffer_load_a_per_ds_read_a =
                num_buffer_load_inst_a / staged_num_ds_read_inst_a / issue_stages;
            constexpr auto staged_num_mfma_per_buffer_load_a =
                issue_stages * staged_num_mfma / num_buffer_load_inst_a;
            // A global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, staged_num_buffer_load_a_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
                    ignore = ibuf_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                });
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_a - num_ds_read_grouped, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            });
            // __builtin_amdgcn_sched_barrier(0);
        }
        else
        {
            // A local Read
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
                    ignore = ids_inst;
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
                });
            });
            // __builtin_amdgcn_sched_barrier(0);
        }
    }
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
...
...
@@ -551,6 +643,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
        __builtin_amdgcn_sched_barrier(0);

        // 0: Warp specialized scheduling
        // 1: unique scheduling
#if 0
        // main body
        if constexpr(HasMainLoop)
        {
...
...
@@ -558,21 +653,20 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        if constexpr(m0.value == 0)
                        {
                            b_blockwise_copy.Run(b_grid_desc,
                                                 b_grid_buf,
                                                 b_block_desc_n0_n1_k0_k1,
                                                 b_block_origin_idx,
                                                 b_thread_bufs(local_read_buf));
                            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                        }
                        else if constexpr(m0.value == 1)
                        if constexpr(m0.value == a_local_write_issue_stage)
                        {
                            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
                            a_blockwise_copy.RunWrite(a_block_desc,
                                                      a_block_buf.At(local_read_buf));
                        }
                        else if constexpr(m0.value == 2)
                        else if constexpr(m0.value == a_global_read_issue_stage)
                        {
                            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
...
...
@@ -586,13 +680,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2,
                                                   I0, I0, k0, I0, ik))>{}];
                                        make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[mfma_reg_buf]
                                                 [Number<b_thread_desc_.CalculateOffset(
...
...
@@ -620,17 +708,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                a_thread_copy_.Run(
                                    a_block_desc_m0_m1_m2_k0_k1_k2,
                                    make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                    make_tuple(
                                        Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                    a_block_buf.At(local_read_buf),
                                    a_thread_desc_,
                                    make_tuple(Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % 2>{},
                                               I0, I0, k0, I0, I0),
                                    make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                    a_thread_buf);
                            });
                        }
...
...
@@ -639,17 +721,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                a_thread_copy_.Run(
                                    a_block_desc_m0_m1_m2_k0_k1_k2,
                                    make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                    make_tuple(
                                        Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                    a_block_buf.At(mfma_reg_buf),
                                    a_thread_desc_,
                                    make_tuple(Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % 2>{},
                                               I0, I0, k0, I0, I0),
                                    make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                    a_thread_buf);
                            });
                        }
...
...
@@ -660,23 +736,219 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                LoopFunc(I0, I1);
                LoopFunc(I1, I0);
                i += 2;
            } while(i < (num_loop - 2));
        }
#elif 1
        const index_t warp_id = __builtin_amdgcn_readfirstlane(get_warp_local_1d_id());
        if(warp_id < 2)
        {
            // main body
            if constexpr(HasMainLoop)
            {
                index_t i = 0;
                do
                {
                    auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                        b_blockwise_copy.Run(b_grid_desc,
                                             b_grid_buf,
                                             b_block_desc_n0_n1_k0_k1,
                                             b_block_origin_idx,
                                             b_thread_bufs(local_read_buf));
                        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            if constexpr(m0.value == a_local_write_issue_stage)
                            {
                                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
                            }
                            else if constexpr(m0.value == a_global_read_issue_stage)
                            {
                                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                            }
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                static_for<0, NRepeat, 1>{}([&](auto n0) {
                                    vector_type<ComputeDataType, KPack> a_thread_vec;
                                    vector_type<ComputeDataType, KPack> b_thread_vec;
                                    static_for<0, KPack, 1>{}([&](auto ik) {
                                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                                make_tuple(n0, I0, k0, ik))>{}];
                                    });
                                    using mfma_input_type =
                                        typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
                                    constexpr index_t c_offset =
                                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                                    b_thread_vec.template AsType<mfma_input_type>(),
                                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                });
                            });
                            if constexpr(m0.value == MRepeat - 1)
                            {
                                block_sync_lds();
                                static_for<0, KRepeat, 1>{}([&](auto k0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                        a_block_buf.At(local_read_buf),
                                        a_thread_desc_,
                                        make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                        a_thread_buf);
                                });
                            }
                            else
                            {
                                static_for<0, KRepeat, 1>{}([&](auto k0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                        a_block_buf.At(mfma_reg_buf),
                                        a_thread_desc_,
                                        make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                        a_thread_buf);
                                });
                            }
                            HotLoopScheduler(m0);
                        });
                    };
                    LoopFunc(I0, I1);
                    LoopFunc(I1, I0);
                    i += 2;
                } while(i < (num_loop - 2));
            }
        }
        else
        {
            // main body
            if constexpr(HasMainLoop)
            {
                index_t i = 0;
                do
                {
                    auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                        b_blockwise_copy.Run(b_grid_desc,
                                             b_grid_buf,
                                             b_block_desc_n0_n1_k0_k1,
                                             b_block_origin_idx,
                                             b_thread_bufs(local_read_buf));
                        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            if constexpr(m0.value == a_local_write_issue_stage)
                            {
                                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
                            }
                            else if constexpr(m0.value == a_global_read_issue_stage)
                            {
                                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                            }
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                static_for<0, NRepeat, 1>{}([&](auto n0) {
                                    vector_type<ComputeDataType, KPack> a_thread_vec;
                                    vector_type<ComputeDataType, KPack> b_thread_vec;
                                    static_for<0, KPack, 1>{}([&](auto ik) {
                                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                                make_tuple(n0, I0, k0, ik))>{}];
                                    });
                                    using mfma_input_type =
                                        typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
                                    constexpr index_t c_offset =
                                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                                    b_thread_vec.template AsType<mfma_input_type>(),
                                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                });
                            });
                            if constexpr(m0.value == MRepeat - 1)
                            {
                                block_sync_lds();
                                static_for<0, KRepeat, 1>{}([&](auto k0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                        a_block_buf.At(local_read_buf),
                                        a_thread_desc_,
                                        make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                        a_thread_buf);
                                });
                            }
                            else
                            {
                                static_for<0, KRepeat, 1>{}([&](auto k0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
                                        a_block_buf.At(mfma_reg_buf),
                                        a_thread_desc_,
                                        make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                                        a_thread_buf);
                                });
                            }
                            HotLoopScheduler_B(m0);
                        });
                    };
                    LoopFunc(I0, I1);
                    LoopFunc(I1, I0);
                    i += 2;
                } while(i < (num_loop - 2));
            }
        }
#endif
        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                if constexpr(m0.value == 0)
                {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(I1));
                }
                else if constexpr(m0.value == MRepeat - 1)
                if constexpr(m0.value == a_local_write_issue_stage)
                {
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
                }
...
...
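The main body above advances two iterations per trip: LoopFunc(I0, I1) computes from one register/LDS buffer while prefetching into the other, LoopFunc(I1, I0) swaps the roles, and i += 2 repeats until num_loop - 2, with the last iterations handled by the tail code. A short host-side sketch of that ping-pong control flow; the printed work is a stand-in for the MFMA and copy code, and num_loop = 10 is an assumption (even, at least 4).

// Host-side sketch of the two-buffer ping-pong advance in the main body.
#include <cstdio>

int main()
{
    const int num_loop = 10; // assumption

    auto loop_func = [](int mfma_reg_buf, int local_read_buf) {
        std::printf("compute from buf %d, prefetch into buf %d\n",
                    mfma_reg_buf, local_read_buf);
    };

    int i = 0;
    do
    {
        loop_func(0, 1); // mirrors LoopFunc(I0, I1)
        loop_func(1, 0); // mirrors LoopFunc(I1, I0)
        i += 2;
    } while(i < (num_loop - 2)); // the remaining two iterations go to the tail path
}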
@@ -745,8 +1017,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple((m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                            b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                make_tuple(n0, I0, k0, ik))>{}];
...
...
@@ -767,14 +1039,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
                if constexpr(m0.value != (MRepeat - 1))
                {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        a_thread_copy_.Run(
                            a_block_desc_m0_m1_m2_k0_k1_k2,
                            make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
                            a_block_buf.At(I1),
                            a_thread_desc_,
                            make_tuple(Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
                            a_thread_buf);
                        a_thread_copy_.Run(
                            a_block_desc_m0_m1_m2_k0_k1_k2,
                            make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
                            a_block_buf.At(I1),
                            a_thread_desc_,
                            make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
                            a_thread_buf);
                    });
                    EpilogueScheduler_2();
...
...