Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6a07464b
Commit
6a07464b
authored
Nov 28, 2024
by
coderfeli
Browse files
change ways but still could not use immediate data as ds_read
parent
405c05c0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
142 additions
and
253 deletions
+142
-253
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+2
-1
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+11
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+12
-12
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+5
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+111
-229
No files found.
cmake/EnableCompilerWarnings.cmake
View file @
6a07464b
...
...
@@ -66,6 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marker
# -Werror
-Wno-option-ignored
-Wsign-compare
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
6a07464b
...
...
@@ -82,7 +82,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"init"
,
"0"
,
"0:random, 1:linear, 2:constant(1)"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
6a07464b
...
...
@@ -69,6 +69,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_size = arg_parser.get_int("
b
");
int n_warmup = arg_parser.get_int("
warmup
");
int n_repeat = arg_parser.get_int("
repeat
");
ck_tile::index_t init_method = arg_parser.get_int("
init
");
using namespace ck_tile::literals;
...
...
@@ -114,14 +115,16 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
// ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
// ck_tile::FillConstant<ADataType>
{
1.f
}
(a_m_k);
// ck_tile::FillConstant<BDataType>
{
1.f
}
(b_k_n);
if (init_method == 0) {
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
} else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else {
ck_tile::FillConstant<ADataType>
{
1.f
}
(a_m_k);
ck_tile::FillConstant<BDataType>
{
1.f
}
(b_k_n);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
6a07464b
...
...
@@ -374,29 +374,29 @@ struct BlockwiseGemmXdlops_pipeline_v4
{
// schedule
constexpr
auto
num_ds_read_inst
=
HotLoopInstList
::
A_LDS_Read_Inst_Num
+
HotLoopInstList
::
B_LDS_Read_Inst_Num
;
HotLoopInstList
::
A_LDS_Read_Inst_Num
+
HotLoopInstList
::
B_LDS_Read_Inst_Num
;
//16
constexpr
auto
num_ds_write_inst
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
//8
;
constexpr
auto
num_buffer_load_inst
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
+
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
HotLoopInstList
::
A_Buffer_Load_Inst_Num
+
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
//8
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
//64
constexpr
auto
num_issue
=
num_buffer_load_inst
;
constexpr
auto
num_issue
=
num_buffer_load_inst
;
// 8
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
: 1
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_buffer_load_inst
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
0x100
,
num_ds_read_inst
/
num_buffer_load_inst
,
0
);
// DS read
: 2
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
: 1
__builtin_amdgcn_sched_group_barrier
(
0x200
,
num_ds_write_inst
/
num_buffer_load_inst
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
0x200
,
num_ds_write_inst
/
num_buffer_load_inst
,
0
);
// DS write
: 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
: 1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
:1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_buffer_load_inst
-
3
,
0
);
// MFMA
0x008
,
num_mfma_inst
/
num_buffer_load_inst
-
3
,
0
);
// MFMA
: 5
});
}
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
6a07464b
...
...
@@ -184,7 +184,6 @@ struct BlockGemmARegBRegCRegV2
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
return
a_block_dstr
;
// return make_static_distributed_tensor<ADataType>(a_block_dstr);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBBlockDistribution
()
...
...
@@ -208,10 +207,13 @@ struct BlockGemmARegBRegCRegV2
template
<
typename
BlockWindow
,
typename
BlockTensor
>
CK_TILE_DEVICE
static
auto
PrefetchLds
(
const
BlockWindow
&
block_window
,
BlockTensor
&
block_tensor
)
{
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
//.get_static_tile_distribution_encoding()
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
// load_tile_raw(block_tensor, make_tile_window_linear_raw(block_window, tileDist));
// return;
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
6a07464b
...
...
@@ -71,6 +71,68 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
// schedule
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
0
>
{});
//32
constexpr
index_t
NPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
1
>
{});
//32
constexpr
index_t
KPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
number
<
2
>
{});
//8
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
number
<
0
>
{});
//2
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
number
<
1
>
{});
//2
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
//8
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
//8
constexpr
index_t
num_buffer_load_inst_a
=
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
VectorSizeA
);
// 4
constexpr
index_t
num_buffer_load_inst_b
=
kNPerBlock
*
kKPerBlock
/
(
BlockSize
*
VectorSizeB
);
// 4
constexpr
index_t
num_ds_write_inst_a
=
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 4
constexpr
index_t
num_ds_write_inst_b
=
kNPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 4
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 8
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
kMPerBlock
*
kKPerBlock
/
(
BlockSize
*
KPerXDL
);
// 8
constexpr
index_t
num_mfma_inst
=
kMPerBlock
*
kNPerBlock
*
kKPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
// 64
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
A_LDS_Read_Inst_Num
:
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
B_LDS_Read_Inst_Num
:
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst
=
num_ds_read_inst_a
+
num_ds_read_inst_b
;
// 16
constexpr
auto
num_ds_write_inst
=
num_ds_write_inst_a
+
num_ds_write_inst_b
;
//8
constexpr
auto
num_buffer_load_inst
=
num_buffer_load_inst_a
+
num_buffer_load_inst_b
;
//8
constexpr
auto
num_issue
=
num_buffer_load_inst
;
// 8
static_for
<
0
,
num_issue
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_issue
,
0
);
// DS read : 2
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA: 1
__builtin_amdgcn_sched_group_barrier
(
0x200
,
num_ds_write_inst
/
num_issue
,
0
);
// DS write : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read :1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_inst
/
num_issue
-
3
,
0
);
// MFMA : 5
});
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
...
...
@@ -158,27 +220,15 @@ struct GemmPipelineAGmemBGmemCRegV1
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
// A LDS tile window for store
auto
a_
store_
lds_window0
=
make_tile_window
(
auto
a_lds_window0
=
make_tile_window
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
a_
store_
lds_window1
=
make_tile_window
(
auto
a_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B LDS tile window for store
auto
b_store_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
b_store_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// A LDS tile for block GEMM
auto
a_load_lds_window0
=
make_tile_window
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
a_load_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_load_lds_window0
=
make_tile_window
(
auto
b_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
b_
load_
lds_window1
=
make_tile_window
(
auto
b_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
...
...
@@ -188,76 +238,62 @@ struct GemmPipelineAGmemBGmemCRegV1
auto
c_block_tile
=
Policy
::
template
BlockGemm
<
Problem
>
::
MakeCBlockTile
();
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// a b register tile for lds prefetch & mfma
auto
a_block_tile0
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
a_block_tile1
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
b_block_tile0
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
auto
b_block_tile1
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
// LDS write 0
LocalPrefill
(
a_
store_
lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_
store_
lds_window0
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
// global read 1
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_sync_lds
();
// local prefetch 0
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_block_tile0
);
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill
(
a_store_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window1
,
b_global_load_tile
,
b_element_func
);
// a b register tile for lds prefetch & mfma
using
ALdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
using
BLdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
{}));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
{}));
ALdsTile
a_block_tile0
;
BLdsTile
b_block_tile0
;
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{}));
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{}));
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
// global read 2
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
index_t
iCounter
=
num_loop
-
2
;
ALdsTile
a_block_tile1
;
BLdsTile
b_block_tile1
;
while
(
iCounter
>
1
)
{
// ping
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_
store_
lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_
store_
lds_window0
,
b_global_load_tile
,
b_element_func
);
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
}
// pong
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_block_tile0
);
LocalPrefill
(
a_
store_
lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_
store_
lds_window1
,
b_global_load_tile
,
b_element_func
);
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{})
);
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{})
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
}
iCounter
-=
2
;
}
...
...
@@ -267,17 +303,17 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_
store_
lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_
store_
lds_window0
,
b_global_load_tile
,
b_element_func
);
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_block_tile0
);
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{})
);
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{})
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
//1
...
...
@@ -288,13 +324,23 @@ struct GemmPipelineAGmemBGmemCRegV1
}
else
{
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; %f, %f. ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)), type_convert<float>(a_block_tile1(i_j_idx)), type_convert<float>(b_block_tile1(i_j_idx)));
// });
// printf("\n");
// });
// }
}
}
return
c_block_tile
;
...
...
@@ -316,170 +362,6 @@ struct GemmPipelineAGmemBGmemCRegV1
}
};
// __device__ static constexpr auto HotLoopScheduler()
// {
// // schedule
// constexpr auto num_ds_read_inst =
// HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
// constexpr auto num_ds_write_inst =
// HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
// ;
// constexpr auto num_buffer_load_inst =
// HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
// ;
// constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// constexpr auto num_issue = num_buffer_load_inst;
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
// });
// }
// CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
// {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});
// constexpr index_t A_LDS_Read_Width = KPerXDL;
// constexpr index_t B_LDS_Read_Width = KPerXDL;
// constexpr index_t A_Buffer_Load_Inst_Num =
// MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
// constexpr index_t B_Buffer_Load_Inst_Num =
// NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
// constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL);
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
// constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
// constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
// constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
// constexpr auto ds_read_a_issue_cycle =
// A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// constexpr auto ds_read_b_issue_cycle =
// B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// constexpr auto ds_read_a_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
// constexpr auto ds_read_b_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// constexpr auto num_dsread_b_mfma =
// (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// // stage 1
// // Separate this part?
// // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// // sizeof(ComputeDataType) /
// // sizeof(BDataType)
// // ? sizeof(ComputeDataType) /
// // sizeof(ADataType) : sizeof(ComputeDataType)
// // / sizeof(BDataType);
// constexpr auto num_mfma_stage1 =
// num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
// constexpr auto num_mfma_per_issue =
// num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// });
// static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// });
// // stage 2
// static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
// ds_read_a_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
// ds_read_b_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment