gaoqiong / composable_kernel_ROCM / Commits / 6e3c786e

Commit 6e3c786e authored Dec 06, 2024 by Jing Zhang
merge develop
parents 1bb510cb 261f1759
Changes: 465
Showing 5 changed files with 1170 additions and 31 deletions (+1170 −31)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp        +111 −0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp     +383 −0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp         +543 −0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp    +73 −0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp   +60 −31
Too many changes to show. To preserve performance, only 465 of 465+ files are displayed.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp   0 → 100644 (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    template <typename DstBlockTile, typename SrcTileWindow>
    CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
                                       SrcTileWindow& dram_tile_window) const
    {
        load_tile(dst_block_tile, dram_tile_window);
        move_tile_window(dram_tile_window, {0, KPerBlock});
    }

    template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
    CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                     const SrcBlockTile& src_block_tile,
                                     const ElementFunction& element_func) const
    {
        const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
        store_tile(lds_tile_window, block_tile_tmp);
    }

    CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
    {
        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // TODO: LDS alignment should come from Policy!
        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
    }

    template <typename ADramBlockWindowTmp, typename ALdsTensorView>
    CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                    const ALdsTensorView& a_lds_block_view) const
    {
        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window =
            make_tile_window(a_lds_block_view,
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());

        auto a_lds_gemm_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

        return make_tuple(std::move(a_copy_dram_window),
                          std::move(a_copy_lds_window),
                          std::move(a_lds_gemm_window));
    }

    template <typename BDramBlockWindowTmp, typename BLdsTensorView>
    CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                    const BLdsTensorView& b_lds_block_view) const
    {
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block_view,
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        auto b_lds_gemm_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

        return make_tuple(std::move(b_copy_dram_window),
                          std::move(b_copy_lds_window),
                          std::move(b_lds_gemm_window));
    }
};

} // namespace ck_tile
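For orientation, a minimal standalone sketch of the shared-memory carve-up that GetABLdsTensorViews performs: the A tile sits at the front of the LDS arena, its byte size is rounded up to a 16-byte boundary, and the B tile starts immediately after. The 256x32 / 128x32 fp16 tile sizes (modeled with 2-byte unsigned short) are assumptions for illustration only and are not taken from this commit, which leaves the actual descriptors to the Policy.

#include <cstddef>

// Mirrors integer_divide_ceil(bytes, 16) * 16 used above.
constexpr std::size_t align_up_16(std::size_t bytes) { return (bytes + 15) / 16 * 16; }

constexpr std::size_t a_bytes   = 256 * 32 * sizeof(unsigned short); // assumed fp16 A tile: 16384 B
constexpr std::size_t b_bytes   = 128 * 32 * sizeof(unsigned short); // assumed fp16 B tile:  8192 B
constexpr std::size_t b_offset  = align_up_16(a_bytes);              // B starts right after aligned A
constexpr std::size_t smem_size = b_offset + b_bytes;                // one arena: [A | B]

static_assert(b_offset == 16384 && smem_size == 24576, "A at offset 0, B at 16 KiB, 24 KiB total");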
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp   0 → 100644 (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
};

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV3<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;

    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;

            constexpr index_t A_Buffer_Load_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
            constexpr index_t B_Buffer_Load_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * VectorSizeB);

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t A_LDS_Read_Inst_Num =
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Read_Inst_Num =
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                                (BlockSize / WaveSize) /
                                                (MPerXDL * NPerXDL * KPerXDL);

            // A/B split schedule
            // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
            constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
                                                    ? A_LDS_Read_Inst_Num
                                                    : A_LDS_Read_Inst_Num / 2;
            constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
                                                    ? B_LDS_Read_Inst_Num
                                                    : B_LDS_Read_Inst_Num / 2;

            constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
            constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;

            constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
            constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;

            constexpr auto num_mfma_inst = C_MFMA_Inst_Num;

            constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
            constexpr auto ds_read_a_issue_cycle =
                A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
            constexpr auto ds_read_b_issue_cycle =
                B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
            constexpr auto ds_read_a_mfma_rate =
                (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
            constexpr auto ds_read_b_mfma_rate =
                (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

            constexpr auto num_dsread_a_mfma =
                (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
            constexpr auto num_dsread_b_mfma =
                (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

            // stage 1
            // Separate this part?
            // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
            //                                               sizeof(ComputeDataType) / sizeof(BDataType)
            //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
            //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
            constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
            constexpr auto num_mfma_per_issue =
                num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
            constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
            constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

            static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
                ignore = i;
                static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                    ignore = idswrite;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
            });
            static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
                ignore = i;
                static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                    ignore = idswrite;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
            });

            // stage 2
            static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
                if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                             ds_read_a_mfma_rate)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
                else
                {
                    __builtin_amdgcn_sched_group_barrier(
                        0x100,
                        num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
                        0); // DS read
                }
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
                if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                             ds_read_b_mfma_rate)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
                }
                else
                {
                    __builtin_amdgcn_sched_group_barrier(
                        0x100,
                        num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
                        0); // DS read
                }
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
        }

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A/B tiles in LDS
            auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);

            // A DRAM tile window for load
            // A LDS tile window for store
            // A LDS tile for block GEMM
            auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
                Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);

            // B DRAM tile window for load
            // B LDS tile window for store
            // B LDS tile for block GEMM
            auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
                Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            ABlockTile a_block_tile;
            BBlockTile b_block_tile;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
            Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);

            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
            __builtin_amdgcn_sched_barrier(0);

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    block_sync_lds();

                    Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                    Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);

                    Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
                    Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

                    block_sync_lds();
                    block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);

                    i += 1;
                } while(i < (num_loop - 1));
            }

            // tail
            if constexpr(TailNum == TailNumber::Full)
            {
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);

            return c_block_tile;
        }
    };

    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem);
    }

    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};

} // namespace ck_tile
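HotLoopScheduler above is pure constexpr bookkeeping: it counts how many buffer_load, ds_write, ds_read, and MFMA instructions one K-iteration issues, then interleaves them with __builtin_amdgcn_sched_group_barrier hints. A worked instance of that arithmetic, under assumed (hypothetical) parameters — 256x128x32 fp16 tile, 256 threads, 32x32x8 warp tile, 2x2 waves per block, VectorSizeA/B = 8 — is sketched below; none of these numbers come from the commit itself.

// Worked example of the HotLoopScheduler instruction budget (assumed parameters, plain ints).
constexpr int A_Buffer_Load_Inst_Num = 256 * 32 / (256 * 8);                  // 4 buffer_loads for A
constexpr int B_Buffer_Load_Inst_Num = 128 * 32 / (256 * 8);                  // 2 buffer_loads for B
constexpr int A_LDS_Read_Inst_Num    = 2 * 256 * 32 / (256 * 8);              // 8 (WaveNumN = 2, KPerXDL = 8)
constexpr int B_LDS_Read_Inst_Num    = 2 * 128 * 32 / (256 * 8);              // 4 (WaveNumM = 2)
constexpr int C_MFMA_Inst_Num = 256 * 128 * 32 / (256 / 64) / (32 * 32 * 8);  // 32 MFMAs per K step
// KPerXDL * sizeof(fp16) == 16 bytes, so ds_read counts are not halved, the ds_read issue
// cost is 8 cycles, and mfma_cycle is 32 because NPerXDL != 16.
constexpr int ds_read_mfma_rate  = (32 - 4 + 2 * 8 - 1) / (2 * 8);            // 2 ds_reads per MFMA slot
constexpr int num_dsread_a_mfma  = (A_LDS_Read_Inst_Num + 2 - 1) / 2;         // 4 MFMA slots carry A ds_reads
constexpr int num_dsread_b_mfma  = (B_LDS_Read_Inst_Num + 2 - 1) / 2;         // 2 MFMA slots carry B ds_reads
constexpr int num_mfma_stage1    = C_MFMA_Inst_Num - (4 + 2);                 // 26 MFMAs left for stage 1
constexpr int num_mfma_per_issue = num_mfma_stage1 / (4 + 2);                 // 4 MFMAs per buffer_load issue
// Stage 1 then emits, per buffer_load: 1 ds_write + 1 MFMA, then 1 VMEM read + 3 MFMAs;
// stage 2 pairs every group of 2 ds_reads with one MFMA.
static_assert(ds_read_mfma_rate == 2 && num_mfma_per_issue == 4, "illustrative schedule shape");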
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp   0 → 100644 (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrMem
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // TODO: Is this 32K value gfx9 arch specific?
    static constexpr index_t MinMemInFlyBytes = 32768;

    static constexpr index_t WgpPerCU =
        (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
        MinMemInFlyBytes / WgpPerCU,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2;

    static constexpr index_t LocalPrefillStages = 1;
    static constexpr index_t GlobalBufferNum    = PrefetchStages;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1)
        {
            return TailNumber::One;
        }
        else if(num_loop % PrefetchStages == 2)
        {
            return TailNumber::Two;
        }
        else if(num_loop % PrefetchStages == 3)
        {
            return TailNumber::Three;
        }
        else if(num_loop % PrefetchStages == 4)
        {
            return TailNumber::Four;
        }
        else if(num_loop % PrefetchStages == 5)
        {
            return TailNumber::Five;
        }
        else if(num_loop % PrefetchStages == 6)
        {
            return TailNumber::Six;
        }
        else if(num_loop % PrefetchStages == 7)
        {
            return TailNumber::Seven;
        }
        else
        {
            return TailNumber::Full;
        }
    }
};

// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrMem<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;

    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A/B tiles in LDS
            // With c++20 could simplify to below line.
            // Currently get error: captured structured bindings are a C++20 extension
            // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
            auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
            auto& a_lds_block  = ab_lds_blocks.at(I0{});
            auto& b_lds_block  = ab_lds_blocks.at(I1{});

            // A DRAM tile window for load
            // A LDS tile window for store
            // A LDS tile for block GEMM
            auto a_windows           = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
            auto& a_copy_dram_window = a_windows.at(I0{});
            auto& a_copy_lds_window  = a_windows.at(I1{});
            auto& a_lds_gemm_window  = a_windows.at(I2{});

            // B DRAM tile window for load
            // B LDS tile window for store
            // B LDS tile for block GEMM
            auto b_windows           = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
            auto& b_copy_dram_window = b_windows.at(I0{});
            auto& b_copy_lds_window  = b_windows.at(I1{});
            auto& b_lds_gemm_window  = b_windows.at(I2{});

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [1, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        block_sync_lds();

                        Base::LocalPrefill(
                            a_copy_lds_window,
                            a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            a_element_func);
                        Base::LocalPrefill(
                            b_copy_lds_window,
                            b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            b_element_func);

                        Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                             a_copy_dram_window);
                        Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                             b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    block_sync_lds();

                    Base::LocalPrefill(
                        a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
                    Base::LocalPrefill(
                        b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
                });

                block_sync_lds();
                block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two)
            {
                HotLoopTail(number<2>{});
            }
            else if constexpr(TailNum == TailNumber::Three)
            {
                HotLoopTail(number<3>{});
            }
            else if constexpr(TailNum == TailNumber::Four)
            {
                HotLoopTail(number<4>{});
            }
            else if constexpr(TailNum == TailNumber::Five)
            {
                HotLoopTail(number<5>{});
            }
            else if constexpr(TailNum == TailNumber::Six)
            {
                HotLoopTail(number<6>{});
            }
            else if constexpr(TailNum == TailNumber::Seven)
            {
                HotLoopTail(number<7>{});
            }
            else if constexpr(TailNum == TailNumber::Full)
            {
                HotLoopTail(number<PrefetchStages>{});
            }

            return c_block_tile;
        }
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A/B tiles in LDS
            // With c++20 could simplify to below line.
            // Currently get error: captured structured bindings are a C++20 extension
            // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
            auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
            auto& a_lds_block  = ab_lds_blocks.at(I0{});
            auto& b_lds_block  = ab_lds_blocks.at(I1{});

            // A DRAM tile window for load
            // A LDS tile window for store
            // A LDS tile for block GEMM
            auto a_windows           = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
            auto& a_copy_dram_window = a_windows.at(I0{});
            auto& a_copy_lds_window  = a_windows.at(I1{});
            auto& a_lds_gemm_window  = a_windows.at(I2{});

            // B DRAM tile window for load
            // B LDS tile window for store
            // B LDS tile for block GEMM
            auto b_windows           = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
            auto& b_copy_dram_window = b_windows.at(I0{});
            auto& b_copy_lds_window  = b_windows.at(I1{});
            auto& b_lds_gemm_window  = b_windows.at(I2{});

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [1, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        // no second block_sync_lds because it's interwave
                        Base::LocalPrefill(
                            a_copy_lds_window,
                            a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            a_element_func);
                        Base::LocalPrefill(
                            b_copy_lds_window,
                            b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            b_element_func);

                        Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                             a_copy_dram_window);
                        Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                             b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    // no second block_sync_lds because it's interwave
                    Base::LocalPrefill(
                        a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
                    Base::LocalPrefill(
                        b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
                });

                block_sync_lds();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two)
            {
                HotLoopTail(number<2>{});
            }
            else if constexpr(TailNum == TailNumber::Three)
            {
                HotLoopTail(number<3>{});
            }
            else if constexpr(TailNum == TailNumber::Four)
            {
                HotLoopTail(number<4>{});
            }
            else if constexpr(TailNum == TailNumber::Five)
            {
                HotLoopTail(number<5>{});
            }
            else if constexpr(TailNum == TailNumber::Six)
            {
                HotLoopTail(number<6>{});
            }
            else if constexpr(TailNum == TailNumber::Seven)
            {
                HotLoopTail(number<7>{});
            }
            else if constexpr(TailNum == TailNumber::Full)
            {
                HotLoopTail(number<PrefetchStages>{});
            }

            return c_block_tile;
        }
    };

    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem);
    }

    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};

} // namespace ck_tile
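The memory pipeline sizes its prefetch depth from the 32 KiB in-flight target rather than using a fixed constant. A worked instance under assumed (hypothetical) values — wave64, a 256-thread work-group, a 256x128x32 fp16 tile — is sketched below, together with how a host-side caller would classify a 7-iteration K loop; the concrete numbers are illustrative only and not taken from the commit.

// Worked example of BaseGemmPipelineAgBgCrMem's prefetch-depth selection (assumed parameters).
constexpr int MinMemInFlyBytes = 32768;
constexpr int WgpPerCU         = (4 * 64 / 256 >= 1) ? 4 * 64 / 256 : 1;           // 1 work-group per CU
constexpr int BytesPerKIter    = (256 * 2 + 128 * 2) * 32;                          // 24576 B of A+B per K step
constexpr int FullMemBandPrefetchStages =
    (MinMemInFlyBytes / WgpPerCU + BytesPerKIter - 1) / BytesPerKIter;              // ceil -> 2
constexpr int PrefetchStages =
    FullMemBandPrefetchStages >= 2
        ? (FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8)
        : 2;                                                                        // clamped to [2, 8] -> 2
static_assert(PrefetchStages == 2, "two K-tiles kept in flight for this assumed shape");
// Host-side loop classification for, say, num_loop = 7 K-iterations:
//   BlockHasHotloop(7)     -> 7 > 2, so the unrolled hot loop runs;
//   GetBlockLoopTailNum(7) -> 7 % 2 == 1, so the epilogue uses TailNumber::One.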
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp   0 → 100644 (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ostream>

#include "ck_tile/core.hpp"

namespace ck_tile {

enum struct GemmPipelineScheduler
{
    Default,
    Intrawave,
    Interwave,
};

enum struct TailNumber
{
    // Single / Double buffer pipeline
    Odd,
    Even,

    // Long prefetch pipeline, up to 8
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
    Empty,
    // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
    // prefetchstages
    Full,
};

} // namespace ck_tile

inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)
{
    switch(s)
    {
    case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
    case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
    case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
    default: os << "";
    }
    return os;
}

inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s)
{
    switch(s)
    {
    case ck_tile::TailNumber::Odd: os << "Odd"; break;
    case ck_tile::TailNumber::Even: os << "Even"; break;
    case ck_tile::TailNumber::One: os << "One"; break;
    case ck_tile::TailNumber::Two: os << "Two"; break;
    case ck_tile::TailNumber::Three: os << "Three"; break;
    case ck_tile::TailNumber::Four: os << "Four"; break;
    case ck_tile::TailNumber::Five: os << "Five"; break;
    case ck_tile::TailNumber::Six: os << "Six"; break;
    case ck_tile::TailNumber::Seven: os << "Seven"; break;
    case ck_tile::TailNumber::Empty: os << "Empty"; break;
    case ck_tile::TailNumber::Full: os << "Full"; break;
    default: os << "";
    }
    return os;
}
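A small usage sketch for the two enums and their stream operators, since they are what host-side kernel-selection code reports: the helper name below is hypothetical; only GemmPipelineScheduler, TailNumber, and the operator<< overloads come from this header.

#include <iostream>

// Logs the pipeline configuration a launcher picked (illustrative helper, not part of the commit).
inline void log_pipeline_config(ck_tile::GemmPipelineScheduler sched, ck_tile::TailNumber tail)
{
    std::cout << "scheduler=" << sched << " tail=" << tail << '\n';
}

// e.g. log_pipeline_config(ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::TailNumber::Full);
// prints: scheduler=Intrawave tail=Full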
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp   (renamed, +60 −31)
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"

 namespace ck_tile {

 // A Tile Window: global memory
 // B Tile Window: global memory
 // C Distributed tensor: register
-template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
-struct BlockGemmPipelineAGmemBGmemCRegV1
+template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+struct GemmPipelineAGmemBGmemCRegV1
 {
     using ADataType      = remove_cvref_t<typename Problem::ADataType>;
     using BDataType      = remove_cvref_t<typename Problem::BDataType>;
     using CDataType      = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

-    static constexpr index_t kBlockSize = Problem::kBlockSize;
+    using ALayout = remove_cvref_t<typename Problem::ALayout>;
+    using BLayout = remove_cvref_t<typename Problem::BLayout>;
+    using CLayout = remove_cvref_t<typename Problem::CLayout>;
+
+    static constexpr index_t BlockSize = Problem::kBlockSize;

     static constexpr index_t kMPerBlock = BlockGemmShape::kM;
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;

-    static constexpr index_t AlignmentA = Problem::AlignmentA;
-    static constexpr index_t AlignmentB = Problem::AlignmentB;
-    static constexpr index_t AlignmentC = Problem::AlignmentC;
+    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
+    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
+    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

-    static constexpr bool kPadA = Problem::kPadA;
-    static constexpr bool kPadB = Problem::kPadB;
-    static constexpr bool kPadC = Problem::kPadC;
+    static constexpr bool kPadM = Problem::kPadM;
+    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadK = Problem::kPadK;

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
-        return ck_tile::integer_divide_ceil(
+        return integer_divide_ceil(
                    sizeof(ADataType) *
                        Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                    16) *
...
@@ -44,7 +48,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
                Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
     }
...
@@ -97,11 +101,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             Policy::template MakeADramTileDistribution<Problem>());

         // A LDS tile window for store
-        auto a_copy_lds_window =
-            make_tile_window(a_lds_block,
-                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
-                             {0, 0},
-                             a_copy_dram_window.get_tile_distribution());
+        auto a_copy_lds_window = make_tile_window(
+            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

         // B DRAM tile window for load
         auto b_copy_dram_window =
...
@@ -111,11 +112,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             Policy::template MakeBDramTileDistribution<Problem>());

         // B LDS tile window for store
-        auto b_copy_lds_window =
-            make_tile_window(b_lds_block,
-                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
-                             {0, 0},
-                             b_copy_dram_window.get_tile_distribution());
+        auto b_copy_lds_window = make_tile_window(
+            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

         // A LDS tile for block GEMM
         auto a_lds_gemm_window = make_tile_window(
...
@@ -126,7 +124,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

         // Block GEMM
-        constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
+        auto block_gemm = Policy::template GetBlockGemm<Problem>();

         // Acc register tile
         auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
...
@@ -145,12 +143,32 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

             // LDS write 0
-            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
-            store_tile(a_copy_lds_window, a_block_tile_tmp);
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegBlockDescriptor<Problem>());
+                shuffle_tile(a_shuffle_tmp, a_block_tile);
+                const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
+                store_tile(a_copy_lds_window, a_block_tile_tmp);
+            }
+            else
+            {
+                store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
+            }

             // LDS write 0
-            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
-            store_tile(b_copy_lds_window, b_block_tile_tmp);
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                shuffle_tile(b_shuffle_tmp, b_block_tile);
+                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
+                store_tile(b_copy_lds_window, b_block_tile_tmp);
+            }
+            else
+            {
+                store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
+            }
         }

         index_t iCounter = num_loop - 1;
...
@@ -176,8 +194,19 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             store_tile(a_copy_lds_window, a_block_tile_tmp);

             // LDS write i + 1
-            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
-            store_tile(b_copy_lds_window, b_block_tile_tmp);
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+            {
+                auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
+                store_tile(b_copy_lds_window,
+                           tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
+            }
+            else
+            {
+                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
+                store_tile(b_copy_lds_window, b_block_tile_tmp);
+            }

             iCounter--;
         }
...
...