gaoqiong / composable_kernel_ROCM · Commits

Commit 6fd51c43, authored Dec 05, 2024 by coderfeli
Parent: 8d2f2f8c

rm useless code
Showing 4 changed files with 188 additions and 314 deletions.
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp  +27 -57
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  +0 -15
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp  +91 -140
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp  +70 -102
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...
...
@@ -11,6 +11,13 @@ namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
// Differences from v1:
// 1. use an mwarp x nwarp = 2x2 warp layout
// 2. use a 32x32x16 block gemm
// 3. expose the A-LDS and B-LDS distributions
// 4. implement a sub-tile for constructing the output C-shuffle sub-tile
// 5. reformat some code
// TODO: merge these variants using the universal gemm
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2
{
...
...
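A standalone sketch of the per-warp iteration counts implied by the "Differences from v1" note above (a 2x2 warp grid running a 32x32x16 warp-level GEMM). The 256x256x32 block-tile size is an assumption chosen for illustration, matching the <4, 2>, <2> annotation further down; it is not a value read from this commit.

// Hedged sketch (host-side, standalone): per-warp iteration counts for a 2x2
// warp grid running a 32x32x16 warp-level GEMM over an assumed 256x256x32 block tile.
#include <cstdio>

int main()
{
    constexpr int kMPerBlock = 256, kNPerBlock = 256, kKPerBlock = 32; // assumed block tile
    constexpr int MWarp = 2, NWarp = 2;                                // mwarp x nwarp = 2x2
    constexpr int MPerWarpGemm = 32, NPerWarpGemm = 32, KPerWarpGemm = 16; // 32x32x16 warp gemm

    constexpr int MIterPerWarp = kMPerBlock / (MWarp * MPerWarpGemm); // 4
    constexpr int NIterPerWarp = kNPerBlock / (NWarp * NPerWarpGemm); // 4
    constexpr int KIterPerWarp = kKPerBlock / KPerWarpGemm;           // 2

    static_assert(MIterPerWarp == 4 && NIterPerWarp == 4 && KIterPerWarp == 2, "sanity check");
    std::printf("MIterPerWarp=%d NIterPerWarp=%d KIterPerWarp=%d\n",
                MIterPerWarp, NIterPerWarp, KIterPerWarp);
    return 0;
}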
@@ -44,22 +51,8 @@ struct BlockGemmARegBRegCRegV2
    std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
    "wrong!");
// M->N Warp
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
    sequence<NWarp>,
    tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
    tuple<sequence<1, 0>>,
    tuple<sequence<1, 0>>,
    sequence<1, 2>,
    sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
    sequence<MWarp>,
    tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
    tuple<sequence<0, 1>>,
    tuple<sequence<0, 1>>,
    sequence<1, 2>,
    sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = MakeABlockDistribution();
constexpr auto b_block_dstr_encode = MakeBBlockDistribution();
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
    sequence<>,
...
...
@@ -69,12 +62,6 @@ struct BlockGemmARegBRegCRegV2
    sequence<1, 2>,
    sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
    a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
    b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
    c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
...
...
@@ -169,36 +156,29 @@ struct BlockGemmARegBRegCRegV2
    return c_block_tensor;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
{
    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<0, 1>>,
        sequence<2>,
        sequence<0>>{};
    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
    return c_block_tensor;
}
// for cshuffle, disable currently
// CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
// {
// constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
// tuple<sequence<1, 2>>,
// tuple<sequence<0, 1>>,
// sequence<2>,
// sequence<0>>{};
// constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// return c_block_tensor;
// }
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{
// M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
    sequence<NWarp>,
    tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // <4, 2>, <2>
    tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
    tuple<sequence<1, 0>>,
    tuple<sequence<1, 0>>,
    sequence<1, 2>,
...
...
@@ -224,16 +204,6 @@ struct BlockGemmARegBRegCRegV2
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
return b_block_dstr;
// return make_static_distributed_tensor<BDataType>(b_block_dstr);
}
// Prefetch lds
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
    auto tileDist = BlockTensor::get_tile_distribution();
    // load_tile(block_tensor, make_tile_window(block_window, tileDist));
    load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
}
// C = A * B
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
...
...
@@ -25,7 +25,6 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
...
...
@@ -214,20 +213,6 @@ struct GemmKernel
    {i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
// using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// {
// CSubTileDistr c_sub_tile;
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// merge_sequences(sequence<1>{}, c_sub_y_lengths));
// EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
// move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
}
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...
...
@@ -39,18 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kHasHotLoop = Problem::kHasHotLoop;
static constexpr auto kTailNum    = Problem::kTailNum;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
// return integer_least_multiple(
// sizeof(ADataType) *
// Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2 +
// integer_least_multiple(
// sizeof(BDataType) *
// Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2;
// }
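The commented-out GetStaticLdsSize above spells out the double-buffered LDS budget: the A and B LDS blocks, each rounded up to a 16-byte multiple, times two buffers. A standalone sketch of that arithmetic follows; the 256x256x32 block tile and fp16 element size are illustrative assumptions, not values read from this commit.

// Hedged sketch: double-buffered LDS budget implied by the commented-out
// GetStaticLdsSize above, for an assumed 256x256x32 block tile of fp16 data.
#include <cstdio>

constexpr int integer_least_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main()
{
    constexpr int kMPerBlock = 256, kNPerBlock = 256, kKPerBlock = 32; // assumed block tile
    constexpr int elem_bytes = 2;                                      // assumed fp16

    constexpr int a_bytes = integer_least_multiple(elem_bytes * kMPerBlock * kKPerBlock, 16);
    constexpr int b_bytes = integer_least_multiple(elem_bytes * kNPerBlock * kKPerBlock, 16);
    constexpr int smem    = (a_bytes + b_bytes) * 2; // x2 for the ping-pong buffers

    std::printf("A tile = %d bytes, B tile = %d bytes, double-buffered total = %d bytes\n",
                a_bytes, b_bytes, smem);
    return 0;
}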
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return Policy::template GetSmemSize<Problem>();
...
...
@@ -58,7 +46,7 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void GlobalPrefetch(DstBlockTile& dst_block_tile,
                                          SrcTileWindow& dram_tile_window)
{
    load_tile(dst_block_tile, dram_tile_window);
    move_tile_window(dram_tile_window, {0, kKPerBlock});
...
...
@@ -66,84 +54,80 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE static void LocalPrefill(DstTileWindow& lds_tile_window,
                                        const SrcBlockTile& src_block_tile,
                                        const ElementFunction& element_func)
{
    const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
    store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void LocalPrefetch(DstBlockTile& dst_block_tile,
                                         const SrcTileWindow& lds_tile_window)
{
    load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
// schedule
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});//32
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});//8
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});//2
// constexpr index_t A_LDS_Read_Width = KPerXDL;//8
// constexpr index_t B_LDS_Read_Width = KPerXDL;//8
// constexpr index_t num_buffer_load_inst_a =
// kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
// constexpr index_t num_buffer_load_inst_b =
// kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
// constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
// constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL); // 64
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
// constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; //8
// constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; //8
// constexpr auto num_issue = num_buffer_load_inst; // 8
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
// });
// __builtin_amdgcn_sched_barrier(0);
static_for<0, 8, 1>{}([&](auto i) {
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{}); // 32
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{}); // 32
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{}); // 8
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{}); // 2
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{}); // 2
constexpr index_t A_LDS_Read_Width = KPerXDL; // 8
constexpr index_t B_LDS_Read_Width = KPerXDL; // 8
constexpr index_t num_buffer_load_inst_a = kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr index_t num_buffer_load_inst_b = kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr index_t num_ds_write_inst_a = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t num_ds_write_inst_b = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t B_LDS_Read_Inst_Num = WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t num_mfma_inst =
    kMPerBlock * kNPerBlock * kKPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
    A_LDS_Read_Width * sizeof(ADataType) == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
    B_LDS_Read_Width * sizeof(BDataType) == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; // 16
constexpr auto num_ds_write_inst = num_ds_write_inst_a + num_ds_write_inst_b; // 8
constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; // 8
constexpr auto num_issue = num_buffer_load_inst; // 8
static_for<0, num_issue, 1>{}([&](auto i) {
    ignore = i;
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
    __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
    __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
}
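A standalone sketch that reproduces the instruction-count bookkeeping in HotLoopScheduler above, so the inline "// 4", "// 8", "// 64" annotations can be checked. The concrete numbers (256x256x32 block tile, 256 threads, 64-wide waves, 32x32x8 MFMA, 8-wide vector loads) are assumptions chosen to match those annotations; they are not read from this commit.

// Hedged sketch: the per-issue instruction mix HotLoopScheduler interleaves,
// recomputed on the host under assumed tile sizes. For fp16 an 8-element LDS
// read is exactly 16 bytes, so the ds_read2 halving in the real code does not apply.
#include <cstdio>

int main()
{
    constexpr int kMPerBlock = 256, kNPerBlock = 256, kKPerBlock = 32; // assumed block tile
    constexpr int BlockSize = 256, WaveSize = 64;
    constexpr int MPerXDL = 32, NPerXDL = 32, KPerXDL = 8;             // assumed 32x32x8 MFMA
    constexpr int WaveNumM = 2, WaveNumN = 2;
    constexpr int VectorSizeA = 8, VectorSizeB = 8;                    // assumed global-load width

    constexpr int num_buffer_load_a = kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
    constexpr int num_buffer_load_b = kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
    constexpr int num_ds_write_a    = kMPerBlock * kKPerBlock / (BlockSize * KPerXDL);     // 4
    constexpr int num_ds_write_b    = kNPerBlock * kKPerBlock / (BlockSize * KPerXDL);     // 4
    constexpr int num_ds_read_a     = WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
    constexpr int num_ds_read_b     = WaveNumM * kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
    constexpr int num_mfma = kMPerBlock * kNPerBlock * kKPerBlock / (BlockSize / WaveSize) /
                             (MPerXDL * NPerXDL * KPerXDL); // 64

    constexpr int num_issue = num_buffer_load_a + num_buffer_load_b; // 8 slots per hot-loop iteration
    static_assert(num_issue == 8 && num_mfma == 64, "matches the inline annotations");

    // Each slot interleaves 1 VMEM load, 2 DS reads, 1 DS write and 8 MFMAs (1+1+1+5).
    std::printf("per slot: mfma=%d ds_read=%d ds_write=%d vmem=1\n",
                num_mfma / num_issue,
                (num_ds_read_a + num_ds_read_b) / num_issue,
                (num_ds_write_a + num_ds_write_b) / num_issue);
    return 0;
}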
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
{
...
...
@@ -226,15 +210,15 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store
auto a_lds_window0 = make_tile_window_linear(
auto a_lds_window0 = make_tile_window(
    a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
auto a_lds_window1 = make_tile_window_linear(
auto a_lds_window1 = make_tile_window(
    a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
// B LDS tile window for store
auto b_lds_window0 = make_tile_window_linear(
auto b_lds_window0 = make_tile_window(
    b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
auto b_lds_window1 = make_tile_window_linear(
auto b_lds_window1 = make_tile_window(
    b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
// Block GEMM
...
...
@@ -253,8 +237,6 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
// local prefetch 0
// a b register tile for lds prefetch & mfma
constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
...
...
@@ -267,8 +249,10 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_window0 = make_tile_window_linear(
    b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(
    b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
load_tile(a_block_tile0, a_lds_ld_window0);
load_tile(b_block_tile0, b_lds_ld_window0);
// local prefetch 0
// a b register tile for lds prefetch & mfma
LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
...
...
@@ -278,43 +262,43 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
ALdsTile a_block_tile1;
BLdsTile b_block_tile1;
if(kHasHotLoop)
{
    index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
    do
    {
        // ping
        {
            block_sync_lds();
            // prefetch lds -> vgpr
            load_tile(a_block_tile1, a_lds_ld_window1);
            load_tile(b_block_tile1, b_lds_ld_window1);
            LocalPrefetch(a_block_tile1, a_lds_ld_window1);
            LocalPrefetch(b_block_tile1, b_lds_ld_window1);
            // prefill -> lds
            LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
            LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
            // prefill global -> vgpr
            // GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
            // GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
            load_tile(a_global_load_tile, a_copy_dram_window);
            load_tile(b_global_load_tile, b_copy_dram_window);
            GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
            GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
            // gemm
            block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});
            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        }
        // pong
        {
            block_sync_lds();
            load_tile(a_block_tile0, a_lds_ld_window0);
            load_tile(b_block_tile0, b_lds_ld_window0);
            // prefetch lds -> vgpr
            LocalPrefetch(a_block_tile0, a_lds_ld_window0);
            LocalPrefetch(b_block_tile0, b_lds_ld_window0);
            // prefill -> lds
            LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
            LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
            // prefill global -> vgpr
            GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
            GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
            // gemm
            block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
...
...
@@ -328,9 +312,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
    block_sync_lds();
    // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
    load_tile(a_block_tile1, a_lds_ld_window1);
    load_tile(b_block_tile1, b_lds_ld_window1);
    LocalPrefetch(a_block_tile1, a_lds_ld_window1);
    LocalPrefetch(b_block_tile1, b_lds_ld_window1);
    LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
    LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
...
...
@@ -338,14 +321,14 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
    block_sync_lds();
    // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
    load_tile(a_block_tile0, a_lds_ld_window0);
    load_tile(b_block_tile0, b_lds_ld_window0);
    LocalPrefetch(a_block_tile0, a_lds_ld_window0);
    LocalPrefetch(b_block_tile0, b_lds_ld_window0);
    block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
    __builtin_amdgcn_sched_barrier(0);
}
}
else
...
...
@@ -353,9 +336,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// //tail 2
{
    block_sync_lds();
    // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
    load_tile(a_block_tile1, a_lds_ld_window1);
    load_tile(b_block_tile1, b_lds_ld_window1);
    LocalPrefetch(a_block_tile1, a_lds_ld_window1);
    LocalPrefetch(b_block_tile1, b_lds_ld_window1);
    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
    static_for<0, 8, 1>{}([&](auto i) {
        ignore = i;
...
...
@@ -365,21 +347,11 @@ struct GemmPipelineAGmemBGmemCRegV1
    });
    __builtin_amdgcn_sched_barrier(0);
}
// 2
{
    block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
    __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
    __builtin_amdgcn_sched_barrier(0);
}
}
/// cccccccccc
// constexpr auto c_lds_block_desc = Policy::template MakeCLdsBlockDescriptor<Problem>();
// auto c_lds_block = make_tensor_view<address_space_enum::lds>(reinterpret_cast<CDataType*>(p_smem), c_lds_block_desc);
// auto c_lds_window0 = make_tile_window(c_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
// store_tile(c_lds_window0, c_block_tile);
// block_sync_lds();
return c_block_tile;
}
...
...
@@ -401,25 +373,4 @@ struct GemmPipelineAGmemBGmemCRegV1
}
};
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...
...
@@ -16,37 +16,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
using BlockGemm = BlockGemmARegBRegCRegV2<Problem, BlockGemmPolicy>;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
#if 1
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...
@@ -88,8 +58,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
    b_lds_block_desc_0,
// make_tuple(make_pass_through_transform(kNPerBlock),
// make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
    make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
               make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
    make_tuple(sequence<1>{}, sequence<0, 2>{}),
...
...
@@ -135,76 +103,76 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1
#else
// fake XOR
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
// {
//     using namespace ck_tile;
//     using ADataType = remove_cvref_t<typename Problem::ADataType>;
//     constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
//     constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
//     constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
//         make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
//         number<kKPerBlock>{});
//     constexpr index_t kK1 = 16 / sizeof(ADataType);
//     constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
//         a_lds_block_desc_d1_d2_d3,
//         make_tuple(
//             make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
//             make_pass_through_transform(2)),
//         make_tuple(sequence<0, 2>{}, sequence<1>{}),
//         make_tuple(sequence<0, 2>{}, sequence<1>{}));
//     constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
//         a_lds_block_desc_d4_d5_d6,
//         make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
//                    make_pass_through_transform(kKPerBlock)),
//         make_tuple(sequence<0, 1>{}, sequence<2>{}),
//         make_tuple(sequence<0>{}, sequence<1>{}));
//     return a_lds_block_desc_m_k;
// }
//
// fake XOR
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
// {
//     using namespace ck_tile;
//     using BDataType = remove_cvref_t<typename Problem::BDataType>;
//     constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
//     constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
//     constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
//         make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
//         number<kKPerBlock>{});
//     constexpr index_t kK1 = 16 / sizeof(BDataType);
//     constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
//         b_lds_block_desc_d1_d2_d3,
//         make_tuple(
//             make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
//             make_pass_through_transform(2)),
//         make_tuple(sequence<0, 2>{}, sequence<1>{}),
//         make_tuple(sequence<0, 2>{}, sequence<1>{}));
//     constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
//         b_lds_block_desc_d4_d5_d6,
//         make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
//                    make_pass_through_transform(kKPerBlock)),
//         make_tuple(sequence<0, 1>{}, sequence<2>{}),
//         make_tuple(sequence<0>{}, sequence<1>{}));
//     return b_lds_block_desc_n_k;
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
    using namespace ck_tile;
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
    constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
        make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
        number<kKPerBlock>{});
    constexpr index_t kK1 = 16 / sizeof(ADataType);
    constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
        a_lds_block_desc_d1_d2_d3,
        make_tuple(
            make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
            make_pass_through_transform(2)),
        make_tuple(sequence<0, 2>{}, sequence<1>{}),
        make_tuple(sequence<0, 2>{}, sequence<1>{}));
    constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
        a_lds_block_desc_d4_d5_d6,
        make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
                   make_pass_through_transform(kKPerBlock)),
        make_tuple(sequence<0, 1>{}, sequence<2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
    using namespace ck_tile;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
    constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
        make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
        number<kKPerBlock>{});
    constexpr index_t kK1 = 16 / sizeof(BDataType);
    constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
        b_lds_block_desc_d1_d2_d3,
        make_tuple(
            make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
            make_pass_through_transform(2)),
        make_tuple(sequence<0, 2>{}, sequence<1>{}),
        make_tuple(sequence<0, 2>{}, sequence<1>{}));
    constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
        b_lds_block_desc_d4_d5_d6,
        make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
                   make_pass_through_transform(kKPerBlock)),
        make_tuple(sequence<0, 1>{}, sequence<2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    return b_lds_block_desc_n_k;
}
#endif
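The MakeALdsBlockDescriptor / MakeBLdsBlockDescriptor variants above swizzle the LDS layout with make_xor_transform so that a column of the tile does not map every row onto the same LDS bank. The standalone sketch below illustrates the general idea of an XOR swizzle on a toy row-by-column tile; it is an illustration of the technique, not a claim about the exact index mapping ck_tile's make_xor_transform produces.

// Hedged sketch: a generic XOR swizzle. Reading logical column 0 of every row
// touches Cols distinct physical columns instead of one, spreading LDS bank accesses.
#include <cstdio>

int main()
{
    constexpr int Rows = 8; // e.g. rows of an LDS tile slice
    constexpr int Cols = 8; // e.g. kKPerBlock / kK1 vectors per row

    for(int r = 0; r < Rows; ++r)
    {
        constexpr int logical_col = 0;               // same logical column in every row
        int physical_col = logical_col ^ (r % Cols); // XOR-swizzled placement
        std::printf("row %d -> physical column %d\n", r, physical_col);
    }
    return 0;
}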
template <typename Problem>
...
...