Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c275904b
Commit
c275904b
authored
Dec 03, 2024
by
coderfeli
Browse files
try to fix hint
parent
730c5fff
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
137 additions
and
92 deletions
+137
-92
include/ck_tile/core/tensor/store_tile.hpp
include/ck_tile/core/tensor/store_tile.hpp
+20
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+117
-92
No files found.
include/ck_tile/core/tensor/store_tile.hpp
View file @
c275904b
...
@@ -76,6 +76,26 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
...
@@ -76,6 +76,26 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
tile_window
.
store
(
dstr_tensor
,
number
<-
1
>
{});
tile_window
.
store
(
dstr_tensor
,
number
<-
1
>
{});
}
}
/// @brief Store a distributed tensor through a linear tile window.
///
/// Thin free-function wrapper: it forwards `dstr_tensor` together with the two
/// compile-time tag arguments straight to `tile_window.store(...)`.
///
/// @tparam i_access              Compile-time access index handed to `store()` as
///                               `number<i_access>`. Defaults to -1.
///                               NOTE(review): -1 presumably means "all accesses" —
///                               confirm against `tile_window_linear::store`.
/// @tparam oob_conditional_check Compile-time flag handed to `store()` as
///                               `bool_constant<oob_conditional_check>`. Defaults to true.
///                               NOTE(review): presumably toggles out-of-bounds
///                               checking — confirm against `tile_window_linear::store`.
/// @param tile_window  Destination window; `store()` is invoked on it.
/// @param dstr_tensor  Source tensor; must share `TileDistribution_` with the window.
///
/// The two trailing unnamed parameters exist only so callers can select
/// `i_access` / `oob_conditional_check` by passing tag objects.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE void
store_tile(const tile_window_linear<BottomTensorView_,
                                    WindowLengths_,
                                    TileDistribution_,
                                    LinearBottomDims_>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
           number<i_access>                     = {},
           bool_constant<oob_conditional_check> = {})
{
    // Re-materialize the tags from the template arguments and delegate to the window.
    tile_window.store(
        dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
c275904b
...
@@ -76,75 +76,74 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -76,75 +76,74 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
{
// schedule
// schedule
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
0
>
{});
//32
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
constexpr
index_t
NPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
1
>
{});
//32
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
constexpr
index_t
KPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
2
>
{});
//8
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
constexpr
index_t
WaveSize
=
64
;
// constexpr index_t WaveSize = 64;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
number
<
0
>
{});
//2
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
number
<
1
>
{});
//2
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
//8
// constexpr index_t A_LDS_Read_Width = KPerXDL;//8
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
//8
// constexpr index_t B_LDS_Read_Width = KPerXDL;//8
constexpr
index_t
num_buffer_load_inst_a
=
// constexpr index_t num_buffer_load_inst_a =
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
VectorSizeA
);
// 4
// kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr
index_t
num_buffer_load_inst_b
=
// constexpr index_t num_buffer_load_inst_b =
kNPerBlock
*
kKPerBlock
/
(
BlockSize
*
VectorSizeB
);
// 4
// kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr
index_t
num_ds_write_inst_a
=
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 4
// constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr
index_t
num_ds_write_inst_b
=
kNPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 4
// constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr
index_t
A_LDS_Read_Inst_Num
=
// constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN
*
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 8
// WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr
index_t
B_LDS_Read_Inst_Num
=
// constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM
*
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 8
// WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr
index_t
num_mfma_inst
=
kMPerBlock
*
kNPerBlock
*
kKPerBlock
/
// constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
(
BlockSize
/
WaveSize
)
/
// (BlockSize / WaveSize) /
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
// 64
// (MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule
// // A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
?
A_LDS_Read_Inst_Num
// ? A_LDS_Read_Inst_Num
:
A_LDS_Read_Inst_Num
/
2
;
// : A_LDS_Read_Inst_Num / 2;
constexpr
auto
num_ds_read_inst_b
=
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
?
B_LDS_Read_Inst_Num
// ? B_LDS_Read_Inst_Num
:
B_LDS_Read_Inst_Num
/
2
;
// : B_LDS_Read_Inst_Num / 2;
constexpr
auto
num_ds_read_inst
=
num_ds_read_inst_a
+
num_ds_read_inst_b
;
// 16
// constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
constexpr
auto
num_ds_write_inst
=
num_ds_write_inst_a
+
num_ds_write_inst_b
;
//8
// constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
constexpr
auto
num_buffer_load_inst
=
num_buffer_load_inst_a
+
num_buffer_load_inst_b
;
//8
// constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
constexpr
auto
num_issue
=
num_buffer_load_inst
;
// 8
// constexpr auto num_issue = num_buffer_load_inst; // 8
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
// static_for<0, num_issue, 1>{}([&](auto i) {
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_issue
,
0
);
// DS read : 2
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA: 1
__builtin_amdgcn_sched_group_barrier
(
0x200
,
num_ds_write_inst
/
num_issue
,
0
);
// DS write : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read :1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_issue
-
3
,
0
);
// MFMA : 5
});
__builtin_amdgcn_sched_barrier
(
0
);
// static_for<0, 8, 1>{}([&](auto i) {
// ignore = i;
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
// });
// });
__builtin_amdgcn_sched_barrier
(
0
);
// __builtin_amdgcn_sched_barrier(0);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x100
,
2
,
0
);
// DS read : 2
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA: 1
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read :1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
5
,
0
);
// MFMA : 5
});
}
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockSubTile
()
{
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockSubTile
()
{
...
@@ -180,23 +179,23 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -180,23 +179,23 @@ struct GemmPipelineAGmemBGmemCRegV1
////////////// global window & register /////////////////
////////////// global window & register /////////////////
// A DRAM tile window for load
// A DRAM tile window for load
auto
a_copy_dram_window
=
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
_linear
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// B DRAM tile window for load
// B DRAM tile window for load
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
_linear
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// A register tile for global load
// A register tile for global load
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
()
)
;
constexpr
auto
ABlockTileDistr
=
a_copy_dram_window
.
get_tile_distribution
();
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
()
)
;
constexpr
auto
BBlockTileDistr
=
b_copy_dram_window
.
get_tile_distribution
();
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}
));
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}
));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
));
ABlockTile
a_global_load_tile
;
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
BBlockTile
b_global_load_tile
;
...
@@ -213,27 +212,35 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -213,27 +212,35 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr
index_t
b_lds_block_space_size_aligned
=
constexpr
index_t
b_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
BDataType
)
*
b_lds_block_desc
.
get_element_space_size
(),
16
);
integer_least_multiple
(
sizeof
(
BDataType
)
*
b_lds_block_desc
.
get_element_space_size
(),
16
);
// A tile in LDS view
// A tile in LDS view
ADataType
*
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
const
ADataType
*
__restrict__
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
ADataType
*
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_a_lds0
)
+
a_lds_block_space_size_aligned
);
const
ADataType
*
__restrict__
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
const
ADataType
*
__restrict__
p_a_lds2
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
const
ADataType
*
__restrict__
p_a_lds3
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
auto
a_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds2
,
a_lds_block_desc
);
auto
a_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds3
,
a_lds_block_desc
);
// B tile in LDS view
// B tile in LDS view
BDataType
*
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_a_lds1
)
+
a_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
BDataType
*
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_b_lds0
)
+
b_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
+
b_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds2
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
const
BDataType
*
__restrict__
p_b_lds3
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
+
b_lds_block_space_size_aligned
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
auto
b_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds2
,
b_lds_block_desc
);
auto
b_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds3
,
b_lds_block_desc
);
// A LDS tile window for store
// A LDS tile window for store
auto
a_lds_window0
=
make_tile_window
(
auto
a_lds_window0
=
make_tile_window
_linear
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
ABlockTileDistr
);
auto
a_lds_window1
=
make_tile_window
(
auto
a_lds_window1
=
make_tile_window
_linear
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
ABlockTileDistr
);
// B LDS tile window for store
// B LDS tile window for store
auto
b_lds_window0
=
make_tile_window
(
auto
b_lds_window0
=
make_tile_window
_linear
(
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
BBlockTileDistr
);
auto
b_lds_window1
=
make_tile_window
(
auto
b_lds_window1
=
make_tile_window
_linear
(
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
BBlockTileDistr
);
// Block GEMM
// Block GEMM
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
...
@@ -260,10 +267,10 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -260,10 +267,10 @@ struct GemmPipelineAGmemBGmemCRegV1
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
ALdsTile
a_block_tile0
;
BLdsTile
b_block_tile0
;
BLdsTile
b_block_tile0
;
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_
window0
,
ALdsTileDistr
);
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_
ld_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_
window1
,
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_
ld_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
ALdsTileDistr
);
auto
b_lds_ld_window0
=
make_tile_window_linear
(
b_lds_
window0
,
BLdsTileDistr
);
auto
b_lds_ld_window0
=
make_tile_window_linear
(
b_lds_
ld_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
BLdsTileDistr
);
auto
b_lds_ld_window1
=
make_tile_window_linear
(
b_lds_
window1
,
BLdsTileDistr
);
auto
b_lds_ld_window1
=
make_tile_window_linear
(
b_lds_
ld_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
}
,
BLdsTileDistr
);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
load_tile
(
b_block_tile0
,
b_lds_ld_window0
);
load_tile
(
b_block_tile0
,
b_lds_ld_window0
);
...
@@ -276,7 +283,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -276,7 +283,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
index_t
iCounter
=
num_loop
-
2
;
index_t
iCounter
=
__builtin_amdgcn_readfirstlane
(
num_loop
-
2
)
;
ALdsTile
a_block_tile1
;
ALdsTile
a_block_tile1
;
BLdsTile
b_block_tile1
;
BLdsTile
b_block_tile1
;
...
@@ -286,19 +293,27 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -286,19 +293,27 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
// ping
{
{
block_sync_lds
();
block_sync_lds
();
//prefetch lds -> vgpr
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
load_tile
(
b_block_tile1
,
b_lds_ld_window1
);
load_tile
(
b_block_tile1
,
b_lds_ld_window1
);
//prefill -> lds
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
//prefill global -> vgpr
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
// GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
load_tile
(
a_global_load_tile
,
a_copy_dram_window
);
load_tile
(
b_global_load_tile
,
b_copy_dram_window
);
// gemm
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
HotLoopScheduler
();
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// pong
// pong
{
{
block_sync_lds
();
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
load_tile
(
b_block_tile0
,
b_lds_ld_window0
);
load_tile
(
b_block_tile0
,
b_lds_ld_window0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
...
@@ -307,6 +322,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -307,6 +322,7 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
}
iCounter
-=
2
;
iCounter
-=
2
;
}
while
(
iCounter
>
1
);
}
while
(
iCounter
>
1
);
...
@@ -346,10 +362,19 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -346,10 +362,19 @@ struct GemmPipelineAGmemBGmemCRegV1
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
load_tile
(
b_block_tile1
,
b_lds_ld_window1
);
load_tile
(
b_block_tile1
,
b_lds_ld_window1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
8
,
0
);
// MFMA
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// 2
// 2
{
{
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
__builtin_amdgcn_sched_group_barrier
(
0x008
,
64
,
0
);
// MFMA
__builtin_amdgcn_sched_barrier
(
0
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment