gaoqiong / composable_kernel_ROCM · Commits

Commit dd0188b3
Authored Jan 06, 2025 by Mateusz Ozga

Merge remote-tracking branch 'origin/develop' into mozga-amd/universal_gemm_weight

Parents: 7ff4d613, 888317e6

Changes: 160 · Showing 20 changed files with 3284 additions and 132 deletions (+3284 -132)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py                                      +2    -4
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py                                       +4    -10
example/ck_tile/01_fmha/fmha_fwd.cpp                                                          +10   -9
example/ck_tile/01_fmha/fmha_fwd.hpp                                                          +12   -2
example/ck_tile/02_layernorm2d/generate.py                                                    +93   -67
example/ck_tile/02_layernorm2d/script/smoke_test.sh                                           +2    -1
example/ck_tile/03_gemm/gemm_basic.hpp                                                        +3    -3
example/ck_tile/03_gemm/run_gemm_example.inc                                                  +4    -4
example/ck_tile/03_gemm/universal_gemm.cpp                                                    +7    -13
example/ck_tile/16_batched_gemm/batched_gemm.cpp                                              +9    -4
example/ck_tile/16_batched_gemm/batched_gemm.hpp                                              +2    -1
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc                                  +4    -0
include/ck/library/utility/host_tensor.hpp                                                    +66   -12
include/ck/library/utility/host_tensor_generator.hpp                                          +30   -0
include/ck/tensor/static_tensor.hpp                                                           +2    -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp     +167  -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp           +403  -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp           +1248 -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp           +530  -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp           +686  -0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py

...
@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
 using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
     fmha_pipeline_problem_{F_idx}>;

-using fmha_kernel_{F_idx} =
-    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
-                                   fmha_pipeline_{F_idx}>;
+using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;

 using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
                                                 {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
...
@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
     _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
     for kernel in kernels:
         f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
-    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
+    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

...
@@ -96,9 +96,7 @@ using fmha_epilogue =
                                    {F_spad}, {F_dvpad}>>;

 using fmha_kernel =
-    ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
-                                  fmha_pipeline,
-                                  fmha_epilogue>;
+    ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;

 static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
...
@@ -176,11 +174,7 @@ using fmha_epilogue =
                                    false, false>>;

 using fmha_kernel =
-    ck_tile::FmhaFwdSplitKVCombineKernel<
-        ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
-            fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
-        fmha_pipeline,
-        fmha_epilogue>;
+    ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;

 static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
...
@@ -261,7 +255,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
                 static_assert({F_bn1} % 32 == 0);
                 if (t.has_lse) {{
-                    if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
+                    if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
                         return -1;
                     }} else {{
                         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
...
@@ -614,7 +608,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
-            '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
+            '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
            '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
            '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
        }
...
example/ck_tile/01_fmha/fmha_fwd.cpp

...
@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         // NOTE: use gpu to do validation
         ck_tile::naive_attention_fwd_traits naive_t;
-        naive_t.q_type    = data_type;
-        naive_t.k_type    = data_type;
-        naive_t.v_type    = data_type;
-        naive_t.o_type    = data_type;
-        naive_t.q_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.k_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.v_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.o_layout  = o_perm == 1 ? "bhsd" : "bshd";
-        naive_t.variation = 0; // TODO?
+        naive_t.q_type     = data_type;
+        naive_t.k_type     = data_type;
+        naive_t.v_type     = data_type;
+        naive_t.o_type     = data_type;
+        naive_t.q_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.k_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.v_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.o_layout   = o_perm == 1 ? "bhsd" : "bshd";
+        naive_t.variation  = 0; // TODO?
+        naive_t.quant_algo = 0;

         ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
...
example/ck_tile/01_fmha/fmha_fwd.hpp

...
@@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         }
     }();

-    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
-    return ck_tile::make_tuple(kargs, grids);
+    if constexpr(FmhaKernel::kIsGroupMode)
+    {
+        dim3 grids = FmhaKernel::GridSize(
+            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
+        return ck_tile::make_tuple(kargs, grids);
+    }
+    else
+    {
+        dim3 grids = FmhaKernel::GridSize(
+            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
+        return ck_tile::make_tuple(kargs, grids);
+    }
 }

 template <typename Kernel>
...
example/ck_tile/02_layernorm2d/generate.py

...
@@ -58,6 +58,7 @@ template <typename XDataType_,
           bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           ck_tile::index_t kFusedAdd_ = 0,
           ck_tile::index_t kFusedQuant_ = 0>
...
@@ -120,6 +121,7 @@ struct layernorm2d_fwd_traits_
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
     static constexpr bool kTwoPass = kTwoPass_;
     static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
     static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
...
@@ -137,6 +139,7 @@ template <typename XDataType_,
           bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           int kFusedAdd_,
           int kFusedQuant_>
...
@@ -152,6 +155,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                         kPadN_,
                                         kSaveMeanInvStd_,
                                         kFastFDiv_,
+                                        kWelford_,
                                         kTwoPass_,
                                         kFusedAdd_,
                                         kFusedQuant_>;
...
@@ -184,6 +188,7 @@ float layernorm2d_fwd_(const S& s, A a)
     using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
                                                          Traits_::kSaveMeanInvStd,
                                                          Traits_::kFastFDiv,
+                                                         Traits_::kWelford,
                                                          Traits_::kTwoPass,
                                                          static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
                                                          static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
...
@@ -204,12 +209,13 @@ float layernorm2d_fwd_(const S& s, A a)
     using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
     using Pipeline        = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;

-    using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
+    using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
     using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;

     static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
+    static constexpr bool UseRawStore = sizeof(YDataType) == 4;
     using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
-        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
+        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
     using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
...
@@ -274,7 +280,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
 #include "layernorm2d_fwd_api_common.hpp"

 // clang-format off
-//        prec_i  prec_o  prec_sy  rm  rn  tm  tn  vn  pd  mv  rpcf  2p  add  sweep
+//        prec_i  prec_o  prec_sy  rm  rn  tm  tn  vn  pd  mv  rpcf  welford  2p  add  sweep
 {F_instance_def}
 // clang-format on
...
@@ -362,6 +368,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     F_kPadN            : bool
     F_kSaveMeanInvStd_ : bool
     F_kFastFDiv_       : bool
+    F_kWelford_        : bool
     F_kTwoPass_        : bool
     F_kFusedAdd        : int
     F_kFusedQuant      : int
...
@@ -369,7 +376,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     @property
     def trait_name(self) -> str:
         t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
-        t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
+        t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
         t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
         return t_
...
@@ -422,11 +429,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     def name_common_header(self) -> str:
         return 'layernorm2d_fwd_api_common'

-    @property
-    def content_api(self) -> str:
+    def content_api(self, args) -> str:
         # 1 sort based on dtype
         t_dtype_dict = dict()
-        blobs = self.get_blobs()
+        blobs = self.get_blobs(args)
         for blob in blobs:
             if blob.F_DataTypePair not in t_dtype_dict:
                 t_dtype_dict[blob.F_DataTypePair] = {}
...
@@ -462,8 +468,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
                 inner_str += self.API_INNER_CASE.format(F_if=get_if_str(idx_in_n, len_in_n, False), F_VEC_COND=_cond, F_instance_func=ins.call_name)
                 #inner_str = inner_str + vec_str
-            n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
-            n_str += self.API_PER_N_CASE.format(F_if=get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
+            n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
+            n_str += self.API_PER_N_CASE.format(F_if=get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
         prec_i, prec_o = dtype_.split(',')
         d_str += self.API_PER_DTYPE.format(F_if=get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
...
@@ -474,7 +480,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     def content_common_header(self) -> str:
         return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)

-    def get_blobs(self):
+    def get_blobs(self, args):
         h_traits = layernorm_fwd_codegen.h_traits
         h_instance = layernorm_fwd_codegen.h_instance
...
@@ -484,60 +490,61 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
         scale_list = [('fp32,fp32')]
         dtype_list = [('fp16,fp16'), ('bf16,bf16'), ('fp16,int8'), ('bf16,int8')]
         # NOTE: only fused-dynamic-quant use int8 out
+        types_8bit  = ('int8', 'fp8')
+        types_16bit = ('int16', 'fp16', 'bf16')

         #fused_add_list = [0, 1, 2]
         #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
         fused_add_list = [0, 1]
         fused_sweep_list = [0, 1]    # NOTE: only single pass can use fused dynamic quant

-        #                   rm  rn  tm   tn  vn  pd    mv     fdiv   2p    add sweep
-        h_trait_dict = {
-            '64'   : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 8,    8, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   16, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 1, True, False, True, False, 0, 0)],
-            '128'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   16, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 1, True, False, True, False, 0, 0)],
-            '256'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,   64, 1, True, False, True, False, 0, 0)],
-            '512'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,   64, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 4,   64, 1, True, False, True, False, 0, 0)],
-            '768'  : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,   64, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 4,   64, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1, 12, 4,   64, 1, True, False, True, False, 0, 0)],
-            '1024' : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 2,  128, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 2,  128, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 2,  128, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 1, True, False, True, False, 0, 0)],
-            '1536' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,   64, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 2,  128, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,  256, 1, True, False, True, False, 0, 0)],
-            '2048' : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 1,  256, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 1,  256, 1, True, False, True, False, 0, 0)],
-            '3072' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  128, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,  256, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 1024, 1, True, False, True, False, 0, 0)],
-            '4096' : [h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 1024, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 1, True, False, True, False, 0, 0)],
-            '6144' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  512, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 1024, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 1024, 1, True, False, True, False, 0, 0)],
-            '8192' : [h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 8, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  512, 4, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 2, True, False, True, False, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 1, 1024, 1, True, False, True, False, 0, 0)],
-            'big'  : [h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 8, True, False, True, True, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 4, True, False, True, True, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 1024, 2, True, False, True, True, 0, 0),
-                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 1, True, False, True, True, 0, 0)]}
+        #                   rm  rn  tm   tn  vn  pd    mv     fdiv  welford  2p    add sweep
+        h_trait_dict = {
+            '64'   : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 8,    8, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   16, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 1, True, False, True, True, False, 0, 0)],
+            '128'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   16, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 1, True, False, True, True, False, 0, 0)],
+            '256'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,   64, 1, True, False, True, True, False, 0, 0)],
+            '512'  : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,   64, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,   64, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,   64, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 4,   64, 1, True, False, True, True, False, 0, 0)],
+            '768'  : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,   64, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 4,   64, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1, 12, 4,   64, 1, True, False, True, True, False, 0, 0)],
+            '1024' : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 2,  128, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 2,  128, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 2,  128, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 1, True, False, True, True, False, 0, 0)],
+            '1536' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,   64, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 2,  128, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,  256, 1, True, False, True, True, False, 0, 0)],
+            '2048' : [h_traits('x', 'y', 'xs', 'ys', 1,  1, 1,  256, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 1,  256, 1, True, False, True, True, False, 0, 0)],
+            '3072' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  128, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,  256, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 1024, 1, True, False, True, True, False, 0, 0)],
+            '4096' : [h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 1024, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 1, True, False, True, True, False, 0, 0)],
+            '6144' : [h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  256, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,  512, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 1024, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 1024, 1, True, False, True, True, False, 0, 0)],
+            '8192' : [h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 8, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  512, 4, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 2, True, False, True, True, False, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  8, 1, 1024, 1, True, False, True, True, False, 0, 0)],
+            'big'  : [h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,  256, 8, True, False, True, True, True, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,  256, 4, True, False, True, True, True, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 1024, 2, True, False, True, True, True, 0, 0),
+                      h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 1024, 1, True, False, True, True, True, 0, 0)]}

         total_blob = list()
         for hs_key in h_trait_dict:
             hs = h_trait_dict[hs_key]
...
@@ -558,16 +565,27 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
                             h_.F_YScaleDataType = scale_x
                             h_.F_kFusedAdd = fused_add
                             h_.F_kFusedQuant = fused_quant
+                            # disable welford update for 8bit and 16 bit smallN
+                            if not h_.F_kTwoPass_:
+                                #disable 16 bit when set args disable_16b_welford
+                                if args.disable_16b_welford and prec_i in types_16bit:
+                                    h_.F_kWelford_ = False
+                                #disable 8bit by default
+                                elif prec_i in types_8bit or prec_o in types_8bit:
+                                    h_.F_kWelford_ = False
+                                #disable 16bit small N
+                                elif prec_i in types_16bit and hs_key == '64':
+                                    h_.F_kWelford_ = False
                             current_hs.append(h_)
                             # + "\n"
                             #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
                         current_n_str = 'big' if hs_key == 'big' else current_n
                         total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
         return total_blob

-    def list_blobs(self) -> None:
+    def list_blobs(self, args) -> None:
         w_p = Path(self.working_path)
         list_p = w_p / 'layernorm2d_fwd_blobs.txt'
-        blobs = self.get_blobs()
+        blobs = self.get_blobs(args)
         with list_p.open('w') as list_f:
             # api related file
             list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
...
@@ -576,11 +594,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
             for b in blobs:
                 list_f.write(str(w_p / (b.name + ".cpp")) + "\n")

-    def gen_blobs(self) -> None:
+    def gen_blobs(self, args) -> None:
         w_p = Path(self.working_path)
-        (w_p / (self.name_api + ".cpp")).write_text(self.content_api)
+        w_str = self.content_api(args)
+        (w_p / (self.name_api + ".cpp")).write_text(w_str)
         (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
-        blobs = self.get_blobs()
+        blobs = self.get_blobs(args)
         for b in blobs:
             (w_p / (b.name + ".cpp")).write_text(b.content)
...
@@ -588,14 +607,14 @@ def list_blobs(args):
     api_list = args.api.split(',')
     for api in api_list:
         if api == 'fwd':
-            layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
+            layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)

 def gen_blobs(args):
     api_list = args.api.split(',')
     for api in api_list:
         if api == 'fwd':
-            layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
+            layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
...
@@ -663,6 +682,13 @@ if __name__ == "__main__":
                         help="codegen receipt.")

+    parser.add_argument("--disable_16b_welford",
+                        default=False,
+                        required=False,
+                        help="enable/disable welford for 16bit datatype n > 64")
+
     args = parser.parse_args()
     # print(f'{args.list_blobs}-{args.gen_blobs}')
...
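The kWelford switch threaded through this generator selects a Welford-style streaming mean/variance reduction inside the single-pass LayerNorm pipeline. For orientation only, here is a generic scalar sketch of that recurrence in plain C++; it is not the ck_tile implementation, and the struct and function names are invented for the illustration.

#include <cstddef>

// Minimal scalar sketch of Welford's one-pass mean/variance update.
// A real kernel applies the same recurrence per thread and then merges
// partial (count, mean, M2) triples across lanes; this shows only the scalar core.
struct WelfordState
{
    double mean       = 0.0;
    double m2         = 0.0; // running sum of squared deviations from the mean
    std::size_t count = 0;
};

inline void welford_push(WelfordState& s, double x)
{
    ++s.count;
    const double delta = x - s.mean;
    s.mean += delta / static_cast<double>(s.count);
    s.m2   += delta * (x - s.mean); // note: uses the updated mean
}

inline double welford_variance(const WelfordState& s)
{
    return s.count > 0 ? s.m2 / static_cast<double>(s.count) : 0.0;
}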
example/ck_tile/02_layernorm2d/script/smoke_test.sh

...
@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
-#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
 done
 done
...
example/ck_tile/03_gemm/gemm_basic.hpp

...
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("b", "1", "batch size")
-        .insert("m", "3840", "m dimension")
+    arg_parser.insert("m", "3840", "m dimension")
         .insert("n", "4096", "n dimension")
         .insert("k", "2048", "k dimension")
         .insert("a_layout", "R", "A tensor data layout - Row by default")
...
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
example/ck_tile/03_gemm/run_gemm_example.inc

...
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

-    ck_tile::index_t batch_size = arg_parser.get_int("b");
-    int n_warmup = arg_parser.get_int("warmup");
-    int n_repeat = arg_parser.get_int("repeat");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
+    int n_warmup = arg_parser.get_int("warmup");
+    int n_repeat = arg_parser.get_int("repeat");

     using namespace ck_tile::literals;
...
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
                               stride_A,
                               stride_B,
                               stride_C,
-                              batch_size,
+                              kbatch,
                               n_warmup,
                               n_repeat);
...
example/ck_tile/03_gemm/universal_gemm.cpp

...
@@ -22,7 +22,7 @@
 #endif

 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
 #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
     // Memory friendly for Interwave scheduler
...
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 #endif
             ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

-    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+    const ck_tile::index_t k_grain  = args.k_batch * K_Tile;
+    const ck_tile::index_t K_split  = (args.K + k_grain - 1) / k_grain * K_Tile;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
     const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                                      has_hot_loop_v,
                                                      tail_number_v>>;
             using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
-            auto kargs = Kernel::MakeKargs(args.p_a,
-                                           args.p_b,
-                                           args.p_c,
-                                           args.M,
-                                           args.N,
-                                           args.K,
-                                           args.stride_A,
-                                           args.stride_B,
-                                           args.stride_C);
-            const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+            auto kargs = Kernel::MakeKernelArgs(args);
+            const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
             constexpr dim3 blocks = Kernel::BlockSize();

             if(!Kernel::IsSupportedArgument(kargs))
...
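The K_split arithmetic introduced above rounds K up so that every split-K partition runs a whole number of K_Tile iterations. Below is a self-contained, host-side sketch of just that integer arithmetic; the sample values and the final num_loop division are assumptions for illustration, and only the rounding formula itself is taken from the diff.

#include <cstdio>

int main()
{
    // Sample values chosen purely for illustration.
    const int K       = 2048; // reduction dimension
    const int K_Tile  = 32;   // K covered per pipeline iteration
    const int k_batch = 4;    // split-K factor (the new "split_k" option)

    // K covered by one pass of all k_batch splits.
    const int k_grain = k_batch * K_Tile;
    // Per-split K, rounded up to a multiple of K_Tile (same formula as the diff).
    const int K_split = (K + k_grain - 1) / k_grain * K_Tile;
    // Assumed: the tile partitioner turns this into a loop count by dividing by K_Tile.
    const int num_loop = K_split / K_Tile;

    std::printf("k_grain=%d K_split=%d num_loop=%d\n", k_grain, K_split, num_loop);
    return 0;
}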
example/ck_tile/16_batched_gemm/batched_gemm.cpp

...
@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
     using CodegenGemmTraits =
         ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
+    using CodegenGemmPipeline =
+        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

     auto kargs = Kernel::MakeKernelArgs(args);

-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();

     if(!Kernel::IsSupportedArgument(kargs))
     {
         throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
     }

     if(s.log_level_ > 0)
     {
         std::cout << "Launching kernel with args:"
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp

...
@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc

...
@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::index_t batch_stride_B,
                           ck_tile::index_t batch_stride_C,
                           ck_tile::index_t batch_count,
+                          ck_tile::index_t kbatch,
                           int n_warmup,
                           int n_repeat)
 {
...
@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
...
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
     ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
     ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
     ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
...
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
                                batch_stride_B,
                                batch_stride_C,
                                batch_count,
+                               kbatch,
                                n_warmup,
                                n_repeat);
...
include/ck/library/utility/host_tensor.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
         else
             os << delim;

-        if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
+        using RangeType = ck::remove_cvref_t<decltype(v)>;
+        if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
+                     std::is_same_v<RangeType, ck::bhalf_t>)
         {
             os << ck::type_convert<float>(v);
         }
+        else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
+        {
+            const auto packed_floats = ck::type_convert<ck::float2_t>(v);
+            const ck::vector_type<float, 2> vector_of_floats{packed_floats};
+            os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
+               << vector_of_floats.template AsType<float>()[ck::Number<1>{}];
+        }
         else
         {
             os << static_cast<T>(v);
...
@@ -266,18 +275,18 @@ struct Tensor
     using Data = std::vector<T>;

     template <typename X>
-    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
     {
     }

     template <typename X, typename Y>
     Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+        : mDesc(lens, strides), mData(GetElementSpaceSize())
     {
     }

     template <typename Lengths>
-    Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
     {
     }
...
@@ -287,7 +296,7 @@ struct Tensor
     {
     }

-    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const
...
@@ -322,7 +331,17 @@ struct Tensor
     std::size_t GetElementSize() const { return mDesc.GetElementSize(); }

-    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+    std::size_t GetElementSpaceSize() const
+    {
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return (mDesc.GetElementSpaceSize() + 1) / 2;
+        }
+        else
+        {
+            return mDesc.GetElementSpaceSize();
+        }
+    }

     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
...
@@ -469,29 +488,64 @@ struct Tensor
     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const
     {
-        return mDesc.GetOffsetFromMultiIndex(is...);
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mDesc.GetOffsetFromMultiIndex(is...) / 2;
+        }
+        else
+        {
+            return mDesc.GetOffsetFromMultiIndex(is...);
+        }
     }

     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        }
     }

     template <typename... Is>
     const T& operator()(Is... is) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        }
     }

     T& operator()(std::vector<std::size_t> idx)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        }
     }

     const T& operator()(std::vector<std::size_t> idx) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        }
     }

     typename Data::iterator begin() { return mData.begin(); }
...
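All of the pk_i4_t branches added above divide the logical element offset by two because two 4-bit values share one byte of host storage, and the element space size is rounded up accordingly. Here is a minimal standalone sketch of that index mapping using plain C++ types rather than the CK ones; the helper name is invented for the illustration.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    const std::size_t logical_elems = 10;                   // number of 4-bit values
    std::vector<std::uint8_t> storage((logical_elems + 1) / 2); // same rounding as GetElementSpaceSize above

    // Logical index i lives in byte i / 2; which nibble it occupies is i % 2.
    auto byte_index = [](std::size_t i) { return i / 2; };

    assert(storage.size() == 5);
    assert(byte_index(6) == 3); // elements 6 and 7 share byte 3
    assert(byte_index(7) == 3);
    return 0;
}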
include/ck/library/utility/host_tensor_generator.hpp

...
@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_1<ck::pk_i4_t>
+{
+    int8_t value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int t         = value + 8;
+        ck::pk_i4_t r = ((t << 4) + t) & 0xff;
+        return r;
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_2
 {
...
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_2<ck::pk_i4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int hi = std::rand() % (max_value - min_value) + min_value + 8;
+        int lo = std::rand() % (max_value - min_value) + min_value + 8;
+
+        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
+        return r;
+    }
+};
+
 #if defined CK_ENABLE_FP8
 template <>
 struct GeneratorTensor_2<ck::f8_t>
...
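Both new generator specializations build a pk_i4_t byte as ((hi << 4) + lo) & 0xff after offsetting each value by 8, which reads as a bias mapping the signed int4 range [-8, 7] onto an unsigned nibble [0, 15]. A hedged, self-contained rendition of that packing follows; the helper names and the bias interpretation are assumptions based only on the arithmetic shown above, not on the CK type definitions.

#include <cassert>
#include <cstdint>

// Pack two signed 4-bit values (assumed range [-8, 7]) into one byte,
// biasing each by 8 so it is stored as an unsigned nibble.
inline std::uint8_t pack_int4_pair(int hi, int lo)
{
    return static_cast<std::uint8_t>((((hi + 8) << 4) + (lo + 8)) & 0xff);
}

inline int unpack_hi(std::uint8_t b) { return ((b >> 4) & 0xf) - 8; }
inline int unpack_lo(std::uint8_t b) { return (b & 0xf) - 8; }

int main()
{
    const std::uint8_t b = pack_int4_pair(1, 1); // GeneratorTensor_1-style constant fill of 1
    assert(b == 0x99);
    assert(unpack_hi(b) == 1 && unpack_lo(b) == 1);
    return 0;
}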
include/ck/tensor/static_tensor.hpp

...
@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
     // Idx is for S, not X. Idx should be aligned with X
     template <typename X,
               typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                      is_known_at_compile_time<Idx>::value &&
                                      Idx::Size() == ndim_,
                                  bool>::type = false>
     __host__ __device__ constexpr X GetAsType(Idx) const
...
@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
     // Idx is for S, not X. Idx should be aligned with X
     template <typename X,
               typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                      is_known_at_compile_time<Idx>::value &&
                                      Idx::Size() == ndim_,
                                  bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Idx, X x)
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
new file mode 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"

namespace ck {

enum struct BlockGemmPipelineVersion
{
    v1, // Naive
    v2, // Mem
    v3, // Comp
    v4, // Comp, double lds buffer
    v5, // Comp, double global prefetch register buffer
};

template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t BlockSize,
          typename ADataType, typename BDataType, typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc, typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        return BlockwiseGemmXdlops_pipeline_v1_b_scale<
            BlkGemmPipeSche, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
    {
        return BlockwiseGemmXdlops_pipeline_v2_b_scale<
            BlkGemmPipeSche, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        return BlockwiseGemmXdlops_pipeline_v3_b_scale<
            BlkGemmPipeSche, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
    {
        return BlockwiseGemmXdlops_pipeline_v4_b_scale<
            BlkGemmPipeSche, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
    {
        return BlockwiseGemmXdlops_pipeline_v5<
            BlkGemmPipeSche, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
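BlockGemmPipeline_Selector follows the usual constexpr-selector idiom: each branch returns a differently typed (empty) pipeline object and the caller recovers the chosen type with decltype. The following is a self-contained sketch of that idiom with invented placeholder types; it is not the CK API or its real template arguments.

#include <type_traits>

// Sketch of compile-time selection via a constexpr function with auto return type:
// each if-constexpr branch returns a distinct type, and decltype captures the winner.
enum struct PipelineVersion { v1, v2 };

struct PipelineV1 { static constexpr int prefetch_stages = 1; };
struct PipelineV2 { static constexpr int prefetch_stages = 2; };

template <PipelineVersion Ver>
constexpr auto pipeline_selector()
{
    if constexpr(Ver == PipelineVersion::v1) { return PipelineV1{}; }
    else                                     { return PipelineV2{}; }
}

using SelectedPipeline = std::remove_cv_t<decltype(pipeline_selector<PipelineVersion::v2>())>;
static_assert(SelectedPipeline::prefetch_stages == 2, "v2 selected at compile time");

int main() { return 0; }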
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
new file mode 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType, typename BDataType, typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc, typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v1_b_scale
{
};

template <index_t BlockSize,
          typename ADataType, typename BDataType, typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc, typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize,
                                               ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector,
                                               BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock,
                                               MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType, BDataType, ComputeDataType, AccDataType,
                                        ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock, NPerBlock, KPerBlock,
                                        MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType, BDataType, ComputeDataType, AccDataType,
                                                   ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock, NPerBlock, KPerBlock,
                                                   MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;
    using Base::I0;
    using Base::KRepeat;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 1;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc, typename ABlockDesc, typename ABlockTransfer,
              typename AGridBuffer, typename ABlockBuffer, typename ABlockTransferStep,
              typename BGridDesc, typename BBlockDesc, typename BBlockTransfer,
              typename BGridBuffer, typename BBlockBuffer, typename BBlockTransferStep,
              typename CThreadBuffer,
              // BScale Thread Copy
              typename BScaleGridBuffer, typename BScaleGridDesc, typename BScaleThreadDesc,
              typename BScaleThreadTransfer, typename BScaleThreadTransferStep>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        // BScaleThreadCopy
        const BScaleGridDesc& b_scale_grid_desc,
        const BScaleThreadDesc& b_scale_thread_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleThreadTransferStep& b_scale_thread_copy_step,
        // num_loop
        index_t num_loop,
        index_t num_loop_per_scale) const
    {
        // assume kperblock = scaleblockk
        ignore = num_loop_per_scale;

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, b_scale_thread_desc,
                                    make_tuple(n0, I0), b_scale_thread_buf);
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

        auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // -------------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf, a_thread_desc_,
                                           make_tuple(m0, I0, k, I0), a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf, b_thread_desc_,
                                           make_tuple(n0, I0, k, I0), b_thread_buf);
                    });
                });

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        c_thread_buf_per_scale.Clear();
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(I0));
                        });
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                            c_thread_buf(Number<c_offset>{}) +=
                                c_thread_buf_per_scale[Number<t>{}] *
                                type_convert<AccDataType>(b_scale_thread_buf[n0]);
                        });
                    });
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, b_scale_thread_desc,
                                            make_tuple(n0, I0), b_scale_thread_buf);
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
                });
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<1>{}));

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf, a_thread_desc_,
                                       make_tuple(m0, I0, k, I0), a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf, b_thread_desc_,
                                       make_tuple(n0, I0, k, I0), b_thread_buf);
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            c_thread_buf_per_scale.GetVectorTypeReference(I0));
                    });
                    static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                        c_thread_buf(Number<c_offset>{}) +=
                            c_thread_buf_per_scale[Number<t>{}] *
                            type_convert<AccDataType>(b_scale_thread_buf[n0]);
                    });
                });
            });
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
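The accumulation step above applies the current B scale once per K-block: partial MFMA results are gathered in c_thread_buf_per_scale and then folded into c_thread_buf multiplied by b_scale_thread_buf[n0]. Here is a scalar, host-side sketch of that block-wise dequantization pattern with toy sizes and invented names; it is not the CK data layout or register organisation.

#include <cstdio>

int main()
{
    // Toy sizes: one output element, K = 4 split into two scale blocks of 2.
    const float a[4]       = {1.0f, 2.0f, 3.0f, 4.0f}; // A row
    const float b_q[4]     = {2.0f, 2.0f, 1.0f, 1.0f}; // quantized B column (already widened to float)
    const float b_scale[2] = {0.5f, 0.25f};            // one scale per K-block of B

    float c = 0.0f;
    for(int blk = 0; blk < 2; ++blk)
    {
        // Accumulate within one scale block, then apply the scale once per block,
        // mirroring c_thread_buf += c_thread_buf_per_scale * b_scale above.
        float partial = 0.0f;
        for(int k = 0; k < 2; ++k)
            partial += a[blk * 2 + k] * b_q[blk * 2 + k];
        c += partial * b_scale[blk];
    }

    std::printf("c = %f\n", c); // (1*2 + 2*2) * 0.5 + (3*1 + 4*1) * 0.25 = 4.75
    return 0;
}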
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
0 → 100644
View file @
dd0188b3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType, typename BDataType, typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc, typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v2_b_scale
{
};
template <index_t BlockSize,
          typename ADataType, typename BDataType, typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc, typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize,
                                               ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector,
                                               BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock,
                                               MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType, BDataType, ComputeDataType, AccDataType,
                                        ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock, NPerBlock, KPerBlock,
                                        MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType, BDataType, ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock, NPerBlock, KPerBlock,
                                                   MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;
    using Base::I0;
    using Base::KRepeat;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
    static constexpr index_t WgpPerCU =
        (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
        32768 / WgpPerCU,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2;

    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = PrefetchStages;
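
    // Illustrative numbers (not from this commit): with BlockSize = 256 and warpSize = 64,
    // WgpPerCU = max(4 * 64 / 256, 1) = 1. For an assumed fp16 tile of MPerBlock = NPerBlock = 128
    // and KPerBlock = 64, one prefetch stage moves (128 * 2 + 128 * 2) * 64 = 32768 bytes, so
    // FullMemBandPrefetchStages = ceil(32768 / 32768) = 1, which the clamp above raises to
    // PrefetchStages = 2 (the clamp keeps the value in [2, 8]).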
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1) { return TailNumber::One; }
        else if(num_loop % PrefetchStages == 2) { return TailNumber::Two; }
        else if(num_loop % PrefetchStages == 3) { return TailNumber::Three; }
        else if(num_loop % PrefetchStages == 4) { return TailNumber::Four; }
        else if(num_loop % PrefetchStages == 5) { return TailNumber::Five; }
        else if(num_loop % PrefetchStages == 6) { return TailNumber::Six; }
        else if(num_loop % PrefetchStages == 7) { return TailNumber::Seven; }
        else { return TailNumber::Full; }
    }
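
    // Example (illustrative, assuming PrefetchStages == 4): num_loop = 9 -> One, 10 -> Two,
    // 11 -> Three, 12 -> Full. The selected tail branch drains however many prefetched stages
    // are still outstanding after the unrolled hot loop exits.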
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc, typename ABlockDesc, typename ABlockTransfer,
              typename AGridBuffer, typename ABlockBuffer, typename ABlockTransferStep,
              typename BGridDesc, typename BBlockDesc, typename BBlockTransfer,
              typename BGridBuffer, typename BBlockBuffer, typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global prefetch [2, PrefetchStages]
        static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        });
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
                    // -------------------------------------------------------------------------
                    block_sync_lds();
                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                               a_block_buf,
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_buf);
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                b_thread_copy_.Run(
                                    b_block_desc_n0_n1_n2_k,
                                    make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                    b_block_buf,
                                    b_thread_desc_,
                                    make_tuple(n0, I0, k, I0),
                                    b_thread_buf);
                            });
                        });
                    });

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    block_sync_lds();
                    a_blockwise_copy.RunWrite(
                        a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
                    b_blockwise_copy.RunWrite(
                        b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                });

                i += PrefetchStages;
            } while(i < (num_loop - PrefetchStages));
        }
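
        // Illustrative only: with an assumed PrefetchStages == 3, the
        // Number<(iprefetch + 1) % PrefetchStages>{} index used by RunWrite above cycles
        // 1, 2, 0, 1, 2, 0, ..., while RunRead refills slot iprefetch, so the GlobalBufferNum
        // staging buffers appear to behave as a small ring: the slot drained into LDS is the
        // oldest outstanding prefetch, never the slot being refilled in the same step.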
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v2_b_scale
<
BlockGemmPipelineScheduler
::
Interwave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
B_K1
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
    static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
    static constexpr index_t KPerInnerLoop  = math::max(KPerThread / NumMacClusters, KPack);
    static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;
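
    // Illustrative numbers (not from this commit): with an assumed KPerThread = 32, KPack = 8 and
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS = 2, KPerInnerLoop = max(32 / 2, 8) = 16
    // and KRepeat = 32 / 16 = 2, i.e. each wave issues its per-thread K work as two MAC clusters
    // of 16 elements (two KPack-wide MFMA groups per cluster).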
static
constexpr
index_t
WgpPerCU
=
(
4
*
warpSize
/
BlockSize
)
>=
1
?
4
*
warpSize
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
const
BScaleGridDesc
&
b_scale_grid_desc
,
// BScaleThreadCopy
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
ignore
=
num_loop_per_scale
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
// Initialize C
c_thread_buf
.
Clear
();
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch [2, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
auto
c_thread_buf_per_scale
=
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
();
// need?
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
// -------------------------------------------------------------------------------------------
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
                    // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
                    // except the first, since we can shorten the non-MAC cluster a bit and there
                    // is no observable negative impact. The desired effect is waves in a workgroup
                    // executing MAC in sync. This avoids out-of-sync waves hijacking MAC resources
                    // from other workgroups and reducing the chance of latency hiding while they
                    // wait for the rest of the workgroup at the eventual sync point.
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
                            // The block_sync_lds() here performs double duty:
                            // A) it safeguards against a data hazard, because the barrier from
                            //    blockwise_gemm is moved here;
                            // B) it reduces VMEM FIFO congestion by applying small delays to
                            //    different wavefronts.
                            // It is performed near the end of the MAC cluster to minimize the
                            // lgkmcnt penalty.
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
// {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc,
// make_tuple(n0, I0),
// b_scale_thread_buf);
// b_scale_thread_copy.MoveSrcSliceWindow(
// b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
// });
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
// b_scale_thread_copy_step.At(Number<1>{}));
// block_sync_lds();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc,
// make_tuple(n0, I0),
// b_scale_thread_buf);
// b_scale_thread_copy.MoveSrcSliceWindow(
// b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
// });
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
// b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
protected:
// K->M loopover
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
MRepeat
*
KPerInnerLoop
>
{},
Number
<
MRepeat
*
KPerInnerLoop
>
{},
I1
));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
NRepeat
*
KPerInnerLoop
>
{},
Number
<
NRepeat
*
KPerInnerLoop
>
{},
I1
));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
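
// Rough shape of the schedule implemented below (a sketch inferred from the code, not normative):
// global prefetch 1 -> prefill LDS -> global prefetch 2 -> local prefetch into registers; each
// hot-loop iteration then overlaps { write LDS from the pending global prefetch, issue the next
// global prefetch, advance the B-scale window, run MFMA on the previously prefetched registers,
// re-read LDS for the next iteration } under HotLoopScheduler()'s sched_group_barrier hints.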
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v3_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v3_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
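
    // Example (illustrative): every num_loop maps to TailNumber::Full here, since the hot loop
    // below runs num_loop - 1 iterations and the single Full tail always finishes the last MFMA
    // pass; e.g. num_loop = 7 means 6 hot-loop iterations plus the one tail pass.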
__device__
static
constexpr
auto
HotLoopScheduler
()
{
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_a_mfma
=
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_b_mfma
=
(
num_ds_read_inst_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
(
num_dsread_a_mfma
+
num_dsread_b_mfma
);
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
// stage 2
static_for
<
0
,
num_dsread_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_a
-
(
num_dsread_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dsread_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_b
-
(
num_dsread_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
// BScaleThreadCopy
const
BScaleGridDesc
&
b_scale_grid_desc
,
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// B scale buffer
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
(
num_loop_per_scale
==
1
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
constexpr
auto
num_scale_k_block
=
BScaleThreadDesc
{}.
GetLength
(
I1
);
constexpr
auto
num_scale_krepeat
=
KRepeat
/
num_scale_k_block
;
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_scale_thread_buf
[
Number
<
n0
*
num_scale_k_block
+
k0
/
num_scale_krepeat
>
{}],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
((
i
+
2
)
%
num_loop_per_scale
==
0
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_scale_thread_buf
[
Number
<
n0
*
num_scale_k_block
+
k0
/
num_scale_krepeat
>
{}],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {

// Compute-optimal pipeline with the highest resource request
// GlobalPrefetchStages: 4
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v4_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v4_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
HotloopUnroll
=
2
;
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
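    // Illustrative sketch (not part of this file): a caller would typically use
    // these helpers to choose the kernel instantiation, e.g.
    //
    //   const bool has_hot_loop = Pipeline::BlockHasHotloop(num_k_loop);
    //   const auto tail_number  = Pipeline::BlockLoopTailNum(num_k_loop);
    //   // ...then dispatch to Run<has_hot_loop, tail_number>(...) accordingly.
    //
    // With HotloopUnroll == 2, an odd num_loop yields TailNumber::Odd and an
    // even num_loop yields TailNumber::Even.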
    __device__ static constexpr void HotLoopScheduler()
    {
        // TODO: Take data type into consideration as pipe ver 3
        // A-B split schedule
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_a =
            (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
        constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;

        constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_b =
            (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
        constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;

        constexpr auto num_mfma_per_issue =
            HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);

        static_for<0, num_issue_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008,
                num_mfma_per_issue - num_dsread_per_issue_a - num_dswrite_per_issue_a,
                0); // MFMA
        });

        static_for<0, num_issue_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008,
                num_mfma_per_issue - num_dsread_per_issue_a - num_dswrite_per_issue_b,
                0); // MFMA
        });
        __builtin_amdgcn_sched_barrier(0);
    }
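    // The sched_group_barrier masks used above follow the convention noted in
    // the inline comments: 0x100 = DS read, 0x200 = DS write, 0x020 = VMEM read,
    // 0x008 = MFMA. Per issue slot, the scheduler interleaves one LDS read or
    // write, or one global load, with the MFMA instructions of the block tile.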
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
              typename BScaleThreadTransfer,
              typename BScaleThreadTransferStep>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        // BScaleThreadCopy
                        const BScaleGridDesc& b_scale_grid_desc,
                        const BScaleThreadDesc& b_scale_thread_desc,
                        BScaleThreadTransfer& b_scale_thread_copy,
                        const BScaleGridBuffer& b_scale_grid_buf,
                        const BScaleThreadTransferStep& b_scale_thread_copy_step,
                        // num loop
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // B scale buffer
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
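        // Two register-side copies of each thread buffer (A, B and B-scale) are
        // kept so the MFMA stage can consume one copy while the LDS-read stage
        // fills the other.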
        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
                                    make_tuple(n0, I0),
                                    b_scale_thread_bufs(I0));

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });

        if(num_loop_per_scale == 1)
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<2>{}));
        }
        else
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<1>{}));
        }
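        // b_scale_thread_copy_step carries three window steps: step 0 is applied
        // after each per-N read above, while the prologue and hot loop pick step 2
        // whenever the number of prefetched K iterations is a multiple of
        // num_loop_per_scale (i.e. the next iteration starts a new scale block)
        // and step 1 otherwise.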
        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
                                    make_tuple(n0, I0),
                                    b_scale_thread_bufs(I1));

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });

        if(2 % num_loop_per_scale == 0)
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<2>{}));
        }
        else
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<1>{}));
        }
        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(I0),
                                       b_scale_thread_bufs(I0)[n0],
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(I0));
                });
            });
        });
        // Local prefill 2
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
                                    make_tuple(n0, I0),
                                    b_scale_thread_bufs(I0));

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });

        if(3 % num_loop_per_scale == 0)
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<2>{}));
        }
        else
        {
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<1>{}));
        }

        // Initialize C
        c_thread_buf.Clear();
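        // From here on the hot loop ping-pongs between the two LDS/register
        // buffers: LoopFunc(I1, I1, I0, I0) computes on buffer 0 while reading
        // buffer 1, and LoopFunc(I0, I0, I1, I1) swaps the roles, so each
        // do/while trip advances the K loop by HotloopUnroll (= 2) iterations.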
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop issues two iterations back to back to implement the
            // double local buffer strategy.
            do
            {
                auto LoopFunc = [&](auto lds_read_buf,
                                    auto lds_read_reg_buf,
                                    auto lds_write_buf,
                                    auto mfma_reg_buf) {
                    block_sync_lds();

                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                               a_block_buf.At(lds_read_buf),
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_bufs(lds_read_reg_buf));
                        });
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_scale_thread_bufs(lds_read_buf)[n0],
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });

                    // B scale copy
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_scale_thread_copy.Run(b_scale_grid_desc,
                                                b_scale_grid_buf,
                                                b_scale_thread_desc,
                                                make_tuple(n0, I0),
                                                b_scale_thread_bufs(lds_read_reg_buf));

                        b_scale_thread_copy.MoveSrcSliceWindow(
                            b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
                    });

                    if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0)
                    {
                        b_scale_thread_copy.MoveSrcSliceWindow(
                            b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
                    }
                    else
                    {
                        b_scale_thread_copy.MoveSrcSliceWindow(
                            b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
                    }

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_bufs[mfma_reg_buf]
                                                     [Number<a_thread_desc_.CalculateOffset(
                                                         make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    HotLoopScheduler();
                };

                LoopFunc(I1, I1, I0, I0);
                LoopFunc(I0, I0, I1, I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
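        // Tail helpers: ReadWriteCompFunc drains a stage that still has a pending
        // LDS write, ReadCompFunc drains a stage that only needs its LDS->register
        // read, and CompFunc issues the final MFMAs from data already in registers.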
        auto ReadWriteCompFunc = [&](auto lds_read_buf,
                                     auto lds_read_reg_buf,
                                     auto lds_write_buf,
                                     auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_scale_thread_bufs(lds_read_buf)[n0],
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };
        auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_scale_thread_bufs(lds_read_buf)[n0],
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler();
        };
        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
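        // With PrefetchStages == 3, up to three pipeline stages remain in flight
        // when the hot loop exits; the tail below retires them in order, using
        // one fewer helper when the loop count is even.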
        // tail
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I1, I1, I0, I0);
            ReadCompFunc(I0, I0, I1);
            CompFunc(I0);
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadCompFunc(I1, I1, I0);
            CompFunc(I1);
        }
    }
    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck