Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f64b1375
Commit
f64b1375
authored
Feb 17, 2025
by
coderfeli
Browse files
merge haocong branch
parents
88412f9e
f18cfec4
Changes
124
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2209 additions
and
2998 deletions
+2209
-2998
Jenkinsfile
Jenkinsfile
+0
-3
example/65_gemm_multiply_multiply/CMakeLists.txt
example/65_gemm_multiply_multiply/CMakeLists.txt
+0
-2
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
...y_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
+134
-142
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
.../blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
+25
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
+107
-119
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
+860
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
...block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
+0
-9
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
+2
-337
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+205
-814
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+2
-280
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+1
-310
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+7
-274
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+194
-93
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+87
-274
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+0
-217
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+514
-90
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
...reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
+34
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+9
-20
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
...erence_tensor_operation/cpu/reference_gemm_multiple_d.hpp
+21
-3
No files found.
Jenkinsfile
View file @
f64b1375
...
@@ -722,9 +722,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
...
@@ -722,9 +722,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
pipeline
{
pipeline
{
agent
none
agent
none
triggers
{
parameterizedCron
(
CRON_SETTINGS
)
}
options
{
options
{
parallelsAlwaysFailFast
()
parallelsAlwaysFailFast
()
}
}
...
...
example/65_gemm_multiply_multiply/CMakeLists.txt
View file @
f64b1375
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp
)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_moe_gemm1 moe_gemm1.cpp
)
add_example_executable
(
example_moe_gemm1 moe_gemm1.cpp
)
...
...
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
View file @
f64b1375
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
View file @
f64b1375
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
namespace
ck
{
namespace
ck
{
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
...
@@ -76,6 +77,30 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
...
@@ -76,6 +77,30 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
>
{};
}
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
static_assert
(
MRepeat
>=
4
,
"MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"
);
return
BlockwiseGemmXdlops_pipeline_bpreshuffle_v3
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
else
{
{
std
::
cerr
<<
"BlockGemmPipeline configuration is not available"
<<
std
::
endl
;
std
::
cerr
<<
"BlockGemmPipeline configuration is not available"
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
View file @
f64b1375
...
@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
...
@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
2
;
template
<
typename
TileDesc_M0_M1_M2_K
>
template
<
typename
TileDesc_M0_M1_M2_K
>
__host__
__device__
static
constexpr
auto
MakeAGemmMmaTileDescriptor
(
const
TileDesc_M0_M1_M2_K
&
)
__host__
__device__
static
constexpr
auto
MakeAGemmMmaTileDescriptor
(
const
TileDesc_M0_M1_M2_K
&
)
...
@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
...
@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
// Global prefetch A1 B1
// Global prefetch A1 B1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_grid_buf
,
b_block_desc_n0_n1_k0_k1
,
b_block_desc_n0_n1_k0_k1
,
...
@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
...
@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
__builtin_amdgcn_sched_barrier
(
0
);
// // Local prefill A1
// // Local prefill A1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
// // Global prefetch A2
// // Global prefetch A2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// Local prefetch A1
// Local prefetch A1
...
@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
...
@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_block_desc_n0_n1_k0_k1
,
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_block_origin_idx
,
b_thread_bufs
(
local_read_buf
));
b_thread_bufs
(
local_read_buf
));
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
mfma_reg_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
local_read_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
// printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
View file @
f64b1375
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
namespace
ck
{
namespace
ck
{
// Compute optimized pipeline
// Compute optimized pipeline
// GlobalPrefetchStages:
2
// GlobalPrefetchStages:
3
// LocalPreFillStages: 2
// LocalPreFillStages: 2
// LocalPreFetchStages: 2
// LocalPreFetchStages: 2
// LocalSharedMemoryBuffer: 2
// LocalSharedMemoryBuffer: 2
...
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
using
Base
::
AMmaKStride
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
2
;
template
<
typename
TileDesc_M0_M1_M2_K
>
template
<
typename
TileDesc_M0_M1_M2_K
>
__host__
__device__
static
constexpr
auto
MakeAGemmMmaTileDescriptor
(
const
TileDesc_M0_M1_M2_K
&
)
__host__
__device__
static
constexpr
auto
MakeAGemmMmaTileDescriptor
(
const
TileDesc_M0_M1_M2_K
&
)
...
@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
__device__
static
constexpr
auto
HotLoopScheduler
()
__device__
static
constexpr
auto
HotLoopScheduler
()
{
{
#if 0
// constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// B global + A local
static_for
<
0
,
num_buffer_load_inst_b
/
2
,
1
>
{}([
&
](
auto
i
)
{
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore
=
i
;
ignore
=
i
;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read B
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read A
});
});
#endif
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
// B global
static_for
<
0
,
num_buffer_load_inst_b
/
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read B
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read A
});
});
// A global
// A global
...
@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
});
});
// A local
// A local
static_for
<
0
,
num_ds_read_inst_a
/
2
,
1
>
{}([
&
](
auto
i
)
{
//
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
ignore
=
i
;
//
ignore = i;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
//
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
2
,
0
);
// DS read
//
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
});
//
});
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
...
@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
b_thread_desc_
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
a_thread_buf
),
Number
<
2
>
{}
>
a_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
// Global prefetch A1, B1
// Global prefetch A1, B1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_grid_buf
,
b_block_desc_n0_n1_k0_k1
,
b_block_desc_n0_n1_k0_k1
,
...
@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
//
//
Local prefill A1
// Local prefill A1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
));
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
)
,
I0
);
//
//
Global prefetch A2
// Global prefetch A2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// Local prefetch A1
// Local prefetch A1
...
@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
a_block_buf
.
At
(
I0
),
a_block_buf
.
At
(
I0
),
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
s
(
I0
)
);
});
});
});
});
// Local prefill A2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
),
I1
);
// // Global prefetch A3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
...
@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
do
do
{
{
auto
LoopFunc
=
[
&
](
auto
mfma_reg_buf
,
auto
local_read_buf
)
{
auto
LoopFunc
=
[
&
](
auto
mfma_reg_buf
,
auto
local_read_buf
)
{
block_sync_lds
();
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_grid_buf
,
b_block_desc_n0_n1_k0_k1
,
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_block_origin_idx
,
b_thread_bufs
(
local_read_buf
));
b_thread_bufs
(
local_read_buf
));
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
local_read_buf
));
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf
.
At
(
local_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_bufs
(
local_read_buf
));
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
mfma_reg_buf
),
mfma_reg_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
local_read_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_bufs
[
mfma_reg_buf
]
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
]
b_thread_bufs
[
mfma_reg_buf
]
[
Number
<
b_thread_desc_
.
CalculateOffset
(
[
Number
<
b_thread_desc_
.
CalculateOffset
(
...
@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
});
});
});
});
block_sync_lds
();
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf
.
At
(
local_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
});
});
HotLoopScheduler
();
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
};
};
...
@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
LoopFunc
(
I1
,
I0
);
LoopFunc
(
I1
,
I0
);
i
+=
2
;
i
+=
2
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
3
));
}
}
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
auto
ReadWriteCompFunc
=
[
&
](
auto
mfma_reg
,
auto
local_read_reg
)
{
block_sync_lds
();
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_grid_buf
,
b_block_desc_n0_n1_k0_k1
,
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_block_origin_idx
,
b_thread_bufs
(
I1
));
b_thread_bufs
(
local_read_reg
));
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf
.
At
(
local_read_reg
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_bufs
(
local_read_reg
));
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
)
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
mfma_reg
),
mfma_reg
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
s
[
mfma_reg
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_bufs
[
mfma_reg
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
});
});
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
};
auto
ReadCompFunc
=
[
&
](
auto
mfma_reg
,
auto
local_read_reg
)
{
block_sync_lds
();
block_sync_lds
();
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
local_read_reg
));
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k0_k1_k2
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf
.
At
(
I1
),
a_block_buf
.
At
(
local_read_reg
),
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
s
(
local_read_reg
)
);
});
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
...
@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
s
[
mfma_reg
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I1
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_bufs
[
mfma_reg
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
});
});
});
});
});
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
HotLoopScheduler
();
//
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
;
else
{
auto
CompFunc
=
[
&
](
auto
mfma_reg
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
...
@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
s
[
mfma_reg
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_bufs
[
mfma_reg
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
...
@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
});
});
});
});
});
});
};
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadCompFunc
(
I0
,
I1
);
CompFunc
(
I1
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I0
,
I1
);
ReadCompFunc
(
I1
,
I0
);
CompFunc
(
I0
);
}
}
}
}
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
0 → 100644
View file @
f64b1375
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
View file @
f64b1375
...
@@ -11,15 +11,6 @@
...
@@ -11,15 +11,6 @@
namespace
ck
{
namespace
ck
{
enum
struct
BlockGemmPipelineVersion
{
v1
,
// Naive
v2
,
// Mem
v3
,
// Comp
v4
,
// Comp, double lds buffer
v5
,
// Comp, double global prefetch register buffer
};
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
BlockGemmPipelineScheduler
BlkGemmPipeSche
,
BlockGemmPipelineScheduler
BlkGemmPipeSche
,
index_t
BlockSize
,
index_t
BlockSize
,
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
View file @
f64b1375
...
@@ -155,158 +155,6 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
...
@@ -155,158 +155,6 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
return
TailNumber
::
Full
;
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
AThreadBuffer
,
typename
BThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
AThreadBuffer
&
a_thread_buf_tail
,
BThreadBuffer
&
b_thread_buf_tail
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
a_thread_buf_tail
=
a_thread_buf
;
b_thread_buf_tail
=
b_thread_buf
;
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
AGridDesc
,
...
@@ -480,6 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
...
@@ -480,6 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_copy_
;
...
@@ -607,191 +456,6 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
...
@@ -607,191 +456,6 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
return
TailNumber
::
Full
;
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
AThreadBuffer
,
typename
BThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
AThreadBuffer
&
a_thread_buf_tail
,
BThreadBuffer
&
b_thread_buf_tail
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
// but except the first, as we can shorten non-MAC cluster a bit and there's no
// observable negative impact. The desired effect is waves in a workgroup
// executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
// resource from other workgroups and reducing the chance of latency hiding by
// waiting for the rest of the workgroup at the eventual sync point.
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
// applying small delays to different wavefronts It is performed
// near the end of MAC cluster to minimize lgkmcnt penalty
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// block_sync_lds();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
});
a_thread_buf_tail
=
a_thread_buf
;
b_thread_buf_tail
=
b_thread_buf
;
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
AGridDesc
,
...
@@ -1023,6 +687,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
...
@@ -1023,6 +687,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
}
}
}
}
protected:
// K->M loopover
// K->M loopover
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
View file @
f64b1375
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
f64b1375
...
@@ -262,227 +262,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -262,227 +262,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
});
});
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
AThreadBuffer
,
typename
BThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
AThreadBuffer
&
a_thread_buf_tail
,
BThreadBuffer
&
b_thread_buf_tail
,
index_t
num_loop
)
const
{
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
2
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf_tail
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf_tail
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
AGridDesc
,
...
@@ -635,69 +414,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -635,69 +414,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
i
+=
1
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
1
));
}
}
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
@@ -731,6 +452,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -731,6 +452,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
}
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
View file @
f64b1375
...
@@ -234,316 +234,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -234,316 +234,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
AThreadBuffer
,
typename
BThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
AThreadBuffer
&
a_thread_buf_tail
,
BThreadBuffer
&
b_thread_buf_tail
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
a_thread_buf
),
Number
<
2
>
{}
>
a_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
));
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
I0
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
I0
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
I0
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
I0
));
});
});
});
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
));
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
auto
LoopFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
]
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
LoopFunc
(
I1
,
I1
,
I0
,
I0
);
LoopFunc
(
I0
,
I0
,
I1
,
I1
);
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
auto
ReadWriteCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
auto
ReadCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
auto
CompFunc
=
[
&
](
auto
mfma_reg_buf
)
{
a_thread_buf_tail
=
a_thread_bufs
[
mfma_reg_buf
];
b_thread_buf_tail
=
b_thread_bufs
[
mfma_reg_buf
];
};
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
);
ReadCompFunc
(
I0
,
I0
,
I1
);
CompFunc
(
I0
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadCompFunc
(
I1
,
I1
,
I0
);
CompFunc
(
I1
);
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
AGridDesc
,
...
@@ -873,6 +563,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -873,6 +563,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}
}
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
View file @
f64b1375
...
@@ -316,270 +316,6 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -316,270 +316,6 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
AThreadBuffer
,
typename
BThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
AThreadBuffer
&
a_thread_buf_tail
,
BThreadBuffer
&
b_thread_buf_tail
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_loop
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_loop
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_loop
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_loop
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_loop
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_loop
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
auto
LoopFunc
=
[
&
](
auto
vmem_buf
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
if
constexpr
(
k0
==
(
KRepeat
-
1
))
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
vmem_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
vmem_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
vmem_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
}
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_loop
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_loop
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_loop
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_loop
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_loop
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_loop
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
};
LoopFunc
(
I0
);
LoopFunc
(
I1
);
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
ReadWriteCompFunc
=
[
&
](
auto
vmem_buf
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
if
constexpr
(
k0
==
(
KRepeat
-
1
))
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
vmem_buf
);
block_sync_lds
();
}
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_loop
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_loop
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_loop
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_loop
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_loop
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_loop
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
};
auto
ReadCompFunc
=
[
&
]()
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf_tail
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf_tail
);
});
});
HotLoopScheduler
();
};
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I0
);
ReadWriteCompFunc
(
I1
);
ReadCompFunc
();
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadWriteCompFunc
(
I0
);
ReadCompFunc
();
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
AGridDesc
,
...
@@ -891,18 +627,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -891,18 +627,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
}
}
}
}
protected:
// A[MRepeat, I1, I1, KPack]
// A[MRepeat, I1, I1, KPack]
static
constexpr
auto
a_thread_desc_
loop
=
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
// B[NRepeat, N1, N2, KPack]
// B[NRepeat, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
loop
=
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
loop
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
3
,
...
@@ -912,19 +649,15 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -912,19 +649,15 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
loop
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
3
,
B_K1
,
B_K1
,
B_K1
>
;
B_K1
>
;
AThreadCopy
a_thread_copy_loop
{
Base
::
CalculateAThreadOriginDataIndex
()};
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_loop
{
Base
::
CalculateBThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
f64b1375
...
@@ -157,7 +157,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -157,7 +157,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
KBatch
);
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -249,30 +249,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -249,30 +249,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// Tail number always full
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
{
// if(arg.KBatch > 1)
if
(
arg
.
KBatch
>
1
)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// Run(kernel);
// }
// }
// else
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
{
...
@@ -295,65 +272,199 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -295,65 +272,199 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
Run
(
kernel
);
Run
(
kernel
);
}
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
// {
// if(arg.KBatch > 1)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel =
// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
// {
// const auto kernel =
// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// Run(kernel);
// }
// }
// else
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel =
// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
// {
// const auto kernel =
// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy,
// TailNumber::Even>;
// Run(kernel);
// }
// }
// }
else
else
{
{
throw
std
::
runtime_error
(
"todo: only v1 & v2 support now"
);
throw
std
::
runtime_error
(
"todo: only v1 & v2 support now"
);
}
}
}
}
#if 0
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v3 support now");
}
}
#endif
return
ave_time
;
return
ave_time
;
}
}
...
@@ -406,13 +517,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -406,13 +517,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
void
*
p_sorted_token_ids
,
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_sorted_expert_ids
,
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
index_t
NumTokens
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -425,13 +533,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -425,13 +533,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
static_cast
<
const
index_t
*>
(
p_sorted_token_ids
),
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
index_t
*>
(
p_sorted_expert_ids
),
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
p_ds
,
p_ds
,
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
NumTokens
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -448,8 +553,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -448,8 +553,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
...
@@ -463,16 +567,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -463,16 +567,13 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
index_t
KBatch
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
override
{
{
// assert(0, "no impl");
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
nullptr
,
nullptr
,
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
p_ds
,
p_ds
,
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
f64b1375
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
f64b1375
...
@@ -1220,38 +1220,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1220,38 +1220,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
}
__device__
static
constexpr
auto
EpilogueScheduler
()
{
constexpr
auto
epilogue_tile
=
MPerBlock
*
NPerBlock
*
CShuffleMXdlPerWavePerShuffle
*
CShuffleNXdlPerWavePerShuffle
/
(
MXdlPerWave
*
NXdlPerWave
);
constexpr
auto
num_mfma_inst
=
BlockwiseGemmPipe
::
HotLoopInstList
::
C_MFMA_Inst_Num
*
CShuffleMXdlPerWavePerShuffle
*
CShuffleNXdlPerWavePerShuffle
/
(
MXdlPerWave
*
NXdlPerWave
);
constexpr
auto
num_ds_write_inst
=
epilogue_tile
/
BlockSize
;
// DefaultMFMA, per-element write
constexpr
auto
num_ds_read_inst
=
epilogue_tile
/
BlockSize
/
CShuffleBlockTransferScalarPerVector_NPerBlock
;
constexpr
auto
num_buffer_store_inst
=
num_ds_read_inst
;
// MFMA:ds_write=1:2
constexpr
auto
num_ds_write_issue
=
num_ds_write_inst
/
2
;
constexpr
auto
num_mfma_block_sync
=
(
num_mfma_inst
-
num_ds_write_issue
)
/
2
;
constexpr
auto
mfma_ds_write_rate
=
MXdlPerWave
==
16
?
2
:
4
;
// Hide ds_write issue latency
static_for
<
0
,
num_ds_write_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
mfma_ds_write_rate
,
0
);
// DS write
});
// Hide block_sync + ds_read latency
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_block_sync
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
,
0
);
// DS read
// Hide block_sync latency
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_block_sync
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x040
,
num_buffer_store_inst
,
0
);
// VMEM write
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
// if arch = gfx942
using
Block2CTileMapDefault
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
Block2CTileMapDefault
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
...
@@ -1429,15 +1397,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1429,15 +1397,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
constexpr
auto
a_thread_desc
=
blockwise_gemm_pipeline
.
a_thread_desc_
;
constexpr
auto
b_thread_desc
=
blockwise_gemm_pipeline
.
b_thread_desc_
;
constexpr
auto
c_thread_desc
=
blockwise_gemm_pipeline
.
c_thread_desc_
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeA
>
(
a_thread_desc
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeA
>
(
b_thread_desc
.
GetElementSpaceSize
());
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
KPerBlock
);
...
@@ -1455,16 +1414,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1455,16 +1414,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
b_block_buf
,
b_block_buf
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
c_thread_buf
,
c_thread_buf
,
a_thread_buf
,
b_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
);
// shuffle C and write out
// shuffle C and write out
{
{
// Last block MFMA
auto
xdlops_gemm
=
blockwise_gemm_pipeline
.
xdlops_gemm
;
constexpr
auto
KRepeat
=
blockwise_gemm_pipeline
.
KRepeat
;
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
"wrong!"
);
...
@@ -1624,9 +1577,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1624,9 +1577,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
// C: LDS -> VGPR
// D: Global -> VGPR
// E: =Epilogue(C, D), VGPR -> Global
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
ThisThreadBlock
,
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
...
@@ -1685,84 +1635,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1685,84 +1635,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
constexpr
auto
KPerInnerLoop
=
blockwise_gemm_pipeline
.
KPerInnerLoop
;
static_for
<
0
,
CShuffleMXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
CShuffleNXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
// make sure it's safe to write to LDS
block_sync_lds
();
block_sync_lds
();
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
shuffle_m0
=
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
+
Number
<
1
>
{})[
Number
<
0
>
{}];
constexpr
auto
shuffle_n0
=
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
+
Number
<
1
>
{})[
Number
<
1
>
{}];
static_for
<
0
,
CShuffleMXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
CShuffleNXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_m0
+
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_n0
+
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_m0
+
m0
,
shuffle_n0
+
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
});
}
// each thread write its data from VGPR to LDS
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
...
@@ -1796,8 +1672,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1796,8 +1672,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
I0
,
cde_lds_and_global_step
);
cde_lds_and_global_step
);
// EpilogueScheduler();
}
}
});
});
}
}
...
@@ -1990,15 +1864,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1990,15 +1864,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
constexpr
auto
a_thread_desc
=
blockwise_gemm_pipeline
.
a_thread_desc_
;
constexpr
auto
b_thread_desc
=
blockwise_gemm_pipeline
.
b_thread_desc_
;
constexpr
auto
c_thread_desc
=
blockwise_gemm_pipeline
.
c_thread_desc_
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeA
>
(
a_thread_desc
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeA
>
(
b_thread_desc
.
GetElementSpaceSize
());
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
KPerBlock
);
...
@@ -2016,16 +1881,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -2016,16 +1881,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
b_block_bufs
,
b_block_bufs
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
c_thread_buf
,
c_thread_buf
,
a_thread_buf
,
b_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
);
// shuffle C and write out
// shuffle C and write out
{
{
// Last block MFMA
auto
xdlops_gemm
=
blockwise_gemm_pipeline
.
xdlops_gemm
;
constexpr
auto
KRepeat
=
blockwise_gemm_pipeline
.
KRepeat
;
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
"wrong!"
);
...
@@ -2243,84 +2102,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -2243,84 +2102,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
constexpr
auto
KPerInnerLoop
=
blockwise_gemm_pipeline
.
KPerInnerLoop
;
static_for
<
0
,
CShuffleMXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
CShuffleNXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
// make sure it's safe to write to LDS
block_sync_lds
();
block_sync_lds
();
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
shuffle_m0
=
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
+
Number
<
1
>
{})[
Number
<
0
>
{}];
constexpr
auto
shuffle_n0
=
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
+
Number
<
1
>
{})[
Number
<
1
>
{}];
static_for
<
0
,
CShuffleMXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
CShuffleNXdlPerWavePerShuffle
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_m0
+
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_n0
+
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
shuffle_m0
+
m0
,
shuffle_n0
+
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
});
}
// each thread write its data from VGPR to LDS
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
...
@@ -2354,8 +2139,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -2354,8 +2139,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
I0
,
cde_lds_and_global_step
);
cde_lds_and_global_step
);
// EpilogueScheduler();
}
}
});
});
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
f64b1375
This diff is collapsed.
Click to expand it.
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
View file @
f64b1375
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -73,9 +73,39 @@ struct ReferencefpAintBGemm : public device::BaseOperator
...
@@ -73,9 +73,39 @@ struct ReferencefpAintBGemm : public device::BaseOperator
ScaleDataType
v_scale
;
ScaleDataType
v_scale
;
ADataType
v_converted_b
;
ADataType
v_converted_b
;
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
// use PassThrough instead of ConvertBF16RTN for reference calculation
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
arg
.
b_element_op_
(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
// same for scale matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
}
v_converted_b
=
type_convert
<
ADataType
>
(
v_b
)
*
v_scale
;
v_converted_b
=
type_convert
<
ADataType
>
(
v_b
)
*
v_scale
;
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
f64b1375
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -68,32 +68,21 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -68,32 +68,21 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
if
constexpr
(
is_same_v
<
ADataType
,
pk_i4_t
>
)
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
{
uint8_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
).
data
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
v_a
=
type_convert
<
ComputeTypeA
>
(
i4
);
}
}
else
else
{
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
}
// same for B matrix
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
{
uint8_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
).
data
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
v_b
=
type_convert
<
ComputeTypeB
>
(
i4
);
}
}
else
else
{
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
View file @
f64b1375
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -74,8 +74,26 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
...
@@ -74,8 +74,26 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
// use PassThrough instead of ConvertBF16RTN for reference calculation
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
v_acc
+=
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
...
...
Prev
1
2
3
4
5
…
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment