Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
39ad271b
"vscode:/vscode.git/clone" did not exist on "27cad566faf9eb534ddcfe671100e5130a6108fc"
Commit
39ad271b
authored
Jul 12, 2024
by
danyao12
Browse files
codegen update
parent
39fc3d4b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
373 additions
and
171 deletions
+373
-171
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+373
-171
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
39ad271b
...
@@ -14,15 +14,11 @@ from codegen.cpp_symbol_map import *
...
@@ -14,15 +14,11 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP
=
{
BWD_DQDKDV_PIPELINE_MAP
=
{
"ks_kts_vr"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR"
,
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR"
,
"qs_ks_vr_dos"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS"
,
"ks_vr"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR"
,
}
}
BWD_DQDKDV_PIPELINE_ENUM_MAP
=
{
BWD_DQDKDV_PIPELINE_ENUM_MAP
=
{
"ks_kts_vr"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR"
,
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR"
,
"qs_ks_vr_dos"
:
"ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS"
,
"ks_vr"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KSVR"
,
}
}
FMHA_BWD_KERNEL_HEADER
=
"""// SPDX-License-Identifier: MIT
FMHA_BWD_KERNEL_HEADER
=
"""// SPDX-License-Identifier: MIT
...
@@ -34,39 +30,41 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
...
@@ -34,39 +30,41 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
=
"""
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G1&G3 -> GdKV
// G4 -> GdQ
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_warp_tile
0
_{F_idx},
fmha_block_warps1_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_warp_tile
1
_{F_idx},
fmha_block_warps0_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_warp_tile
0
_{F_idx},
fmha_block_warps1_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_warp_tile
1
_{F_idx},
fmha_block_warps2_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
fmha_warp_tile
0
_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_skpad},
{F_dpad},
{F_dpad},
{F_dvpad},
{F_dvpad},
{F_bias},
{F_bias},
{F_dbias},
{F_dbias},
false,
false,
{F_dropout}
,
false
,
false,
{F_occupancy}>;
{F_occupancy}>
;
using fmha_mask_{F_idx} = {F_mask}
;
using fmha_
mask
_{F_idx} = {F_
mask
};
using fmha_
dropout
_{F_idx}
= {F_
dropout
};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
...
@@ -86,55 +84,73 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
...
@@ -86,55 +84,73 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
fmha_bwd_shape_{F_idx},
{F_mode},
{F_mode},
{F_deterministic},
fmha_mask_{F_idx},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_bwd_trait_{F_idx}>;
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} =
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>;
false,
false>>;
using fmha_bwd_dv_epilogue_{F_idx} =
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>;
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_pipeline_enum},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
#include <iostream>
#include <iostream>
template<>
template
<>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
auto [kargs, grids]
= fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr dim3 blocks
= k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
}}
template<>
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
using k_
= fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
auto [kargs, grids]
= fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr dim3 blocks
= k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
}}
template<>
template
<>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
...
@@ -146,14 +162,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
...
@@ -146,14 +162,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API
=
"""
FMHA_BWD_API
=
"""
#include <iostream>
#include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
template
<typename dot_do_o_trait_, typename dq_dk_dv_trait_
, typename convert_dq_trait_
>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() <<
", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() <<
std::flush;
return ck_tile::launch_kernel(s,
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
);
}}
}}
...
@@ -173,38 +190,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
...
@@ -173,38 +190,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
}}
"""
"""
FMHA_BWD_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
FMHA_BWD_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
"""
"""
@
dataclass
@
dataclass
class
FmhaBwdDQDKDVApiTrait
:
class
FmhaBwdDQDKDVApiTrait
:
pipeline
:
str
pipeline
:
str
# sync with fmha_bwd_traits<>, to generate fallback calls
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim
:
str
hdim
:
str
dtype
:
str
# data type
dtype
:
str
# data type
mode
:
str
# value from MODE_MAP
mode
:
str
# value from MODE_MAP
bm0
:
int
# tile size along q seqlen (block size)
bm0
:
int
# tile size along q seqlen (block size)
bn0
:
int
# tile size along k seqlen
bn0
:
int
# tile size along k seqlen
bhdq
:
int
# q head_dim
bhdq
:
int
# q head_dim
bhdv
:
int
# v head_dim
bhdv
:
int
# v head_dim
mask
:
str
mask
:
str
bias
:
str
bias
:
str
dbias
:
str
dbias
:
str
dropout
:
str
dropout
:
str
spad
:
str
spad
:
str
skpad
:
str
skpad
:
str
dpad
:
str
dpad
:
str
dvpad
:
str
dvpad
:
str
deterministic
:
str
@
property
def
name
(
self
)
->
str
:
return
f
'
{
self
.
pipeline
}
-
{
self
.
hdim
}
-
{
self
.
dtype
}
-
{
self
.
mode
}
-
{
self
.
mask
}
-
{
self
.
bias
}
-
{
self
.
dbias
}
-
{
self
.
dropout
}
-
{
self
.
spad
}
-
{
self
.
skpad
}
-
{
self
.
dpad
}
-
{
self
.
dvpad
}
'
def
scheck
(
self
,
spad1
:
str
)
->
str
:
def
scheck
(
self
,
spad1
:
str
)
->
str
:
if
self
.
mode
==
'group'
:
if
self
.
mode
==
'group'
:
...
@@ -212,9 +227,9 @@ class FmhaBwdDQDKDVApiTrait:
...
@@ -212,9 +227,9 @@ class FmhaBwdDQDKDVApiTrait:
elif
self
.
spad
==
't'
and
spad1
==
't'
:
elif
self
.
spad
==
't'
and
spad1
==
't'
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
!= 0'
return
f
'a.seqlen_q %
{
self
.
bm0
}
!= 0'
elif
self
.
spad
==
'f'
and
spad1
==
't'
:
elif
self
.
spad
==
'f'
and
spad1
==
't'
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0 and a.seqlen_q %
25
6 != 0'
# BlockSize
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0 and a.seqlen_q % 6
4
!= 0'
else
:
# self.skpad == 'f' and skpad1 == 'f'
else
:
# self.skpad == 'f' and skpad1 == 'f'
return
f
'a.seqlen_q %
25
6 == 0'
# BlockSize
return
f
'a.seqlen_q % 6
4
== 0'
@
property
@
property
def
skcheck
(
self
)
->
str
:
def
skcheck
(
self
)
->
str
:
...
@@ -256,16 +271,21 @@ class FmhaBwdApiPool:
...
@@ -256,16 +271,21 @@ class FmhaBwdApiPool:
per_hdim_case
=
str
()
per_hdim_case
=
str
()
for
j
,
hdim
in
enumerate
(
self
.
dq_dk_dv_pool
[
dtype
].
keys
()):
for
j
,
hdim
in
enumerate
(
self
.
dq_dk_dv_pool
[
dtype
].
keys
()):
traits
=
self
.
dq_dk_dv_pool
[
dtype
][
hdim
]
traits
=
self
.
dq_dk_dv_pool
[
dtype
][
hdim
]
hdim_int
=
int
(
hdim
)
inners
=
str
()
inners
=
str
()
for
k
,
trait
in
enumerate
(
traits
):
for
k
,
trait
in
enumerate
(
traits
):
if_k
=
'if'
if
k
==
0
else
'else if'
if_k
=
'if'
if
k
==
0
else
'else if'
for
spad1
in
[
"t"
,
"f"
]:
for
spad1
in
[
"t"
,
"f"
]:
if
((
spad1
==
"f"
and
trait
.
spad
==
"t"
)
or
(
trait
.
mode
==
"group"
and
spad1
==
"f"
)):
if
(
spad1
==
"f"
and
(
trait
.
spad
==
"t"
or
trait
.
mode
==
"group"
)):
continue
if
(
spad1
==
"t"
and
trait
.
spad
==
"f"
and
hdim_int
<=
64
):
continue
continue
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout
=
BOOL_MAP
[
trait
.
dropout
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
],
F_scheck
=
trait
.
scheck
(
spad1
=
spad1
),
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
],
F_scheck
=
trait
.
scheck
(
spad1
=
spad1
),
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
],
F_spad0
=
BOOL_MAP
[
trait
.
spad
],
F_spad1
=
BOOL_MAP
[
spad1
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
])
F_spad0
=
BOOL_MAP
[
trait
.
spad
],
F_spad1
=
BOOL_MAP
[
spad1
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_deterministic
=
BOOL_MAP
[
trait
.
deterministic
])
if_j
=
'if'
if
j
==
0
else
'else if'
if_j
=
'if'
if
j
==
0
else
'else if'
per_hdim_case
=
per_hdim_case
+
FMHA_BWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_BWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
...
@@ -300,74 +320,82 @@ class FmhaBwdDQDKDVTileSize:
...
@@ -300,74 +320,82 @@ class FmhaBwdDQDKDVTileSize:
F_rm2
:
int
# number of warps along k seqlen (block warps) in gemm4
F_rm2
:
int
# number of warps along k seqlen (block warps) in gemm4
F_rn2
:
int
# number of warps along q seqlen (block warps) in gemm4
F_rn2
:
int
# number of warps along q seqlen (block warps) in gemm4
F_rk2
:
int
# number of warps along gemm-k (not used) in gemm4
F_rk2
:
int
# number of warps along gemm-k (not used) in gemm4
F_wm
:
int
# warp size along m (warp size)
F_wm0
:
int
# warp size along m in gemm0/gemm2/gemm4
F_wn
:
int
# warp size along n
F_wn0
:
int
# warp size along n in gemm0/gemm2/gemm4
F_wk
:
int
# warp size along k
F_wk0
:
int
# warp size along k in gemm0/gemm2/gemm4
F_wm1
:
int
# warp size along m in gemm1/gemm3
F_wn1
:
int
# warp size along n in gemm1/gemm3
F_wk1
:
int
# warp size along k in gemm1/gemm3
F_occupancy
:
int
# occupancy
F_occupancy
:
int
# occupancy
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk2
}
x
{
self
.
F_bk3
}
x
{
self
.
F_bk4
}
x
{
self
.
F_bhdq
}
x
{
self
.
F_bhdv
}
"
+
\
return
f
"b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
x
{
self
.
F_bk0
}
x
{
self
.
F_bk1
}
x
{
self
.
F_bk2
}
x
{
self
.
F_bk3
}
x
{
self
.
F_bk4
}
x
{
self
.
F_bhdq
}
x
{
self
.
F_bhdv
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
_r
{
self
.
F_rm2
}
x
{
self
.
F_rn2
}
x
{
self
.
F_rk2
}
"
+
\
f
"_r
{
self
.
F_rm0
}
x
{
self
.
F_rn0
}
x
{
self
.
F_rk0
}
_r
{
self
.
F_rm1
}
x
{
self
.
F_rn1
}
x
{
self
.
F_rk1
}
_r
{
self
.
F_rm2
}
x
{
self
.
F_rn2
}
x
{
self
.
F_rk2
}
"
+
\
f
"_w
{
self
.
F_wm
}
x
{
self
.
F_wn
}
x
{
self
.
F_wk
}
_o
{
self
.
F_occupancy
}
"
f
"_w
{
self
.
F_wm
0
}
x
{
self
.
F_wn
0
}
x
{
self
.
F_wk
0
}
_w
{
self
.
F_wm1
}
x
{
self
.
F_wn1
}
x
{
self
.
F_wk1
}
_o
{
self
.
F_occupancy
}
"
@
dataclass
@
dataclass
class
FmhaBwdDQDKDVKernel
:
class
FmhaBwdDQDKDVKernel
:
F_idx
:
int
# this is not a tunable, but a counter to differentiate symbol
F_idx
:
int
# this is not a tunable, but a counter to differentiate symbol
F_hdim
:
int
# hdim
F_hdim
:
int
# hdim
F_dtype
:
str
# data type
F_dtype
:
str
# data type
F_tile
:
FmhaBwdDQDKDVTileSize
F_tile
:
FmhaBwdDQDKDVTileSize
F_spad
:
str
# true/false
F_spad
:
str
# true/false
F_skpad
:
str
#
F_skpad
:
str
#
F_dpad
:
str
#
F_dpad
:
str
#
F_dvpad
:
str
#
F_dvpad
:
str
#
F_bias
:
str
#
F_bias
:
str
#
F_dbias
:
str
#
F_dbias
:
str
#
F_dropout
:
str
#
F_dropout
:
str
#
F_mask
:
str
# value from MASK_MAP
F_mask
:
str
# value from MASK_MAP
F_mode
:
str
# value from MODE_MAP
F_mode
:
str
# value from MODE_MAP
F_pipeline
:
str
F_deterministic
:
str
#
mask_impl
:
str
F_pipeline
:
str
#
mask_impl
:
str
#
@
property
@
property
def
template
(
self
)
->
str
:
def
template
(
self
)
->
str
:
return
FMHA_BWD_KERNEL_HEADER
+
\
return
FMHA_BWD_KERNEL_HEADER
+
\
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
.
format
(
FMHA_BWD_DQ_DK_DV_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bm0
=
self
.
F_tile
.
F_bm0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bn0
=
self
.
F_tile
.
F_bn0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
F_bk0
=
self
.
F_tile
.
F_bk0
,
F_bk1
=
self
.
F_tile
.
F_bk1
,
F_bk1
=
self
.
F_tile
.
F_bk1
,
F_bk2
=
self
.
F_tile
.
F_bk2
,
F_bk2
=
self
.
F_tile
.
F_bk2
,
F_bk3
=
self
.
F_tile
.
F_bk3
,
F_bk3
=
self
.
F_tile
.
F_bk3
,
F_bk4
=
self
.
F_tile
.
F_bk4
,
F_bk4
=
self
.
F_tile
.
F_bk4
,
F_bhdq
=
self
.
F_tile
.
F_bhdq
,
F_bhdq
=
self
.
F_tile
.
F_bhdq
,
F_bhdv
=
self
.
F_tile
.
F_bhdv
,
F_bhdv
=
self
.
F_tile
.
F_bhdv
,
F_rm0
=
self
.
F_tile
.
F_rm0
,
F_rm0
=
self
.
F_tile
.
F_rm0
,
F_rn0
=
self
.
F_tile
.
F_rn0
,
F_rn0
=
self
.
F_tile
.
F_rn0
,
F_rk0
=
self
.
F_tile
.
F_rk0
,
F_rk0
=
self
.
F_tile
.
F_rk0
,
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rm1
=
self
.
F_tile
.
F_rm1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rn1
=
self
.
F_tile
.
F_rn1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_rk1
=
self
.
F_tile
.
F_rk1
,
F_rm2
=
self
.
F_tile
.
F_rm2
,
F_rm2
=
self
.
F_tile
.
F_rm2
,
F_rn2
=
self
.
F_tile
.
F_rn2
,
F_rn2
=
self
.
F_tile
.
F_rn2
,
F_rk2
=
self
.
F_tile
.
F_rk2
,
F_rk2
=
self
.
F_tile
.
F_rk2
,
F_wm
=
self
.
F_tile
.
F_wm
,
F_wm0
=
self
.
F_tile
.
F_wm0
,
F_wn
=
self
.
F_tile
.
F_wn
,
F_wn0
=
self
.
F_tile
.
F_wn0
,
F_wk
=
self
.
F_tile
.
F_wk
,
F_wk0
=
self
.
F_tile
.
F_wk0
,
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_wm1
=
self
.
F_tile
.
F_wm1
,
F_skpad
=
BOOL_MAP
[
self
.
F_skpad
],
F_wn1
=
self
.
F_tile
.
F_wn1
,
F_dpad
=
BOOL_MAP
[
self
.
F_dpad
],
F_wk1
=
self
.
F_tile
.
F_wk1
,
F_dvpad
=
BOOL_MAP
[
self
.
F_dvpad
],
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_bias
=
BIAS_MAP
[
self
.
F_bias
],
F_skpad
=
BOOL_MAP
[
self
.
F_skpad
],
F_dbias
=
BOOL_MAP
[
self
.
F_dbias
],
F_dpad
=
BOOL_MAP
[
self
.
F_dpad
],
F_dropout
=
BOOL_MAP
[
self
.
F_dropout
],
F_dvpad
=
BOOL_MAP
[
self
.
F_dvpad
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_bias
=
BIAS_MAP
[
self
.
F_bias
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_mask
],
F_dbias
=
BOOL_MAP
[
self
.
F_dbias
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
F_dropout
=
DROPOUT_MAP
[
self
.
F_dropout
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_mask
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
F_deterministic
=
BOOL_MAP
[
self
.
F_deterministic
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
self
.
F_pipeline
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
self
.
F_pipeline
],
F_pipeline
=
BWD_DQDKDV_PIPELINE_MAP
[
self
.
F_pipeline
])
F_pipeline
=
BWD_DQDKDV_PIPELINE_MAP
[
self
.
F_pipeline
])
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
...
@@ -388,7 +416,8 @@ class FmhaBwdDQDKDVKernel:
...
@@ -388,7 +416,8 @@ class FmhaBwdDQDKDVKernel:
if
self
.
F_mask
==
's_mask'
:
n
+=
f
'_mask'
if
self
.
F_mask
==
's_mask'
:
n
+=
f
'_mask'
else
:
else
:
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_dropout
==
't'
:
n
+=
'_dropout'
if
self
.
F_dropout
!=
'no'
:
n
+=
f
'_
{
self
.
F_dropout
}
'
if
self
.
F_deterministic
==
't'
:
n
+=
'_deterministic'
return
n
return
n
@
property
@
property
...
@@ -411,19 +440,23 @@ class FmhaBwdDQDKDVKernel:
...
@@ -411,19 +440,23 @@ class FmhaBwdDQDKDVKernel:
spad
=
self
.
F_spad
,
spad
=
self
.
F_spad
,
skpad
=
self
.
F_skpad
,
skpad
=
self
.
F_skpad
,
dpad
=
self
.
F_dpad
,
dpad
=
self
.
F_dpad
,
dvpad
=
self
.
F_dvpad
)
dvpad
=
self
.
F_dvpad
,
deterministic
=
self
.
F_deterministic
)
# TODO: design a more practical way to do it
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
# this is current supported tile size & pipeline.
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
[
FmhaBwdDQDKDVTileSize
(
128
,
128
,
32
,
32
,
32
,
32
,
32
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
1
),
# '32' : [FmhaBwdDQDKDVTileSize( 64, 64, 32, 64, 32, 64, 64, 32, 32, 1, 2, 1, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 1),
"qs_ks_vr_dos"
],
# "kr_ktr_vr"],
'64'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
32
,
32
,
32
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
1
),
'64'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
64
,
64
,
64
,
64
,
64
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
1
),
"qs_ks_vr_dos"
],
"kr_ktr_vr"
],
'128'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
32
,
32
,
32
,
32
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
1
),
# '128' : [FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 32, 32, 16, 32, 32, 16, 1),
"ks_vr"
]
# "kr_ktr_vr"],
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr"]
}
}
else
:
else
:
return
None
return
None
...
@@ -438,7 +471,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -438,7 +471,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
d
=
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
for
hdim_str
,
mode
,
mask
,
bias
,
dbias
,
dropout
,
spad
,
skpad
,
dpad
,
dvpad
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
(),
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
for
hdim_str
,
mode
,
mask
,
bias
,
dbias
,
dropout
,
spad
,
skpad
,
dpad
,
dvpad
,
deterministic
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
(),
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
DROPOUT_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
tile
=
d
[
hdim_str
][
0
]
tile
=
d
[
hdim_str
][
0
]
ppl
=
d
[
hdim_str
][
1
]
ppl
=
d
[
hdim_str
][
1
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
...
@@ -446,10 +479,12 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -446,10 +479,12 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
continue
if
((
bias
==
"no"
or
bias
==
"alibi"
)
and
dbias
==
"t"
):
if
((
bias
==
"no"
or
bias
==
"alibi"
)
and
dbias
==
"t"
):
continue
continue
if
((
hdim
<=
128
and
(
"wg16"
in
dropout
))
or
(
hdim
==
256
and
(
"wg32"
in
dropout
))):
continue
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
F_bias
=
bias
,
F_dbias
=
dbias
,
F_dropout
=
dropout
,
F_mask
=
mask
,
F_mode
=
mode
,
F_bias
=
bias
,
F_dbias
=
dbias
,
F_dropout
=
dropout
,
F_mask
=
mask
,
F_mode
=
mode
,
F_pipeline
=
ppl
,
mask_impl
=
mask_impl
)
F_pipeline
=
ppl
,
mask_impl
=
mask_impl
,
F_deterministic
=
deterministic
)
if
kernel_filter
!=
None
:
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
continue
...
@@ -466,53 +501,55 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -466,53 +501,55 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
FMHA_BWD_DOT_DO_O_KERNEL_BODY
=
"""
FMHA_BWD_DOT_DO_O_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
using fmha_bwd_dot_do_o_trait_{F_idx} =
{F_dvpad},
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
{F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */
25
6,
/* BlockSize = */ 6
4
,
{F_hdim},
{F_hdim},
{F_mode},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<
using fmha_bwd_dot_do_o_{F_idx} =
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwd
OGradDotO
TilePartitioner</* BlockSize = */
25
6>,
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwd
Q
TilePartitioner</* BlockSize = */ 6
4
>,
fmha_bwd_dot_do_o_{F_idx}>;
fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
#include <iostream>
template<>
template
<>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
auto [kargs, grids]
= fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr dim3 blocks
= k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
}}
template<>
template
<>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
using k_
= fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
auto [kargs, grids]
= fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr dim3 blocks
= k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
}}
template<>
template
<>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
...
@@ -582,12 +619,171 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
...
@@ -582,12 +619,171 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return
gen
return
gen
# C++ source template for one convert-dq kernel instantiation.
# FmhaBwdConvertQGradKernel.template fills it in with str.format(): every
# {F_*} field is a placeholder supplied by that property, and doubled braces
# ({{ and }}) come out as literal C++ braces after formatting.
# The text defines the tile/trait/pipeline aliases for the kernel and three
# explicit specializations keyed on convert_dq_trait_{F_idx}:
# fmha_bwd_convert_dq_ (timed launch through ck_tile::launch_kernel),
# fmha_bwd_convert_dq_oneshot_ (direct launch on s.stream_id_) and
# fmha_bwd_convert_dq_get_name_ (returns k_::GetName()).
FMHA_BWD_CONVERT_DQ_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_hdim}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_bwd_convert_dq_shape_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
fmha_bwd_convert_dq_shape_{F_idx},
fmha_bwd_convert_dq_trait_{F_idx},
{F_mode},
{F_deterministic}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>,
fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdConvertQGradKernel:
    """Parameters of one generated convert-dq kernel source file.

    Every field is substituted into FMHA_BWD_CONVERT_DQ_KERNEL_BODY by
    ``template``; all fields except ``F_idx`` are also encoded into ``name``,
    which in turn determines ``filename``.
    """
    F_idx: int            # not a tunable: a counter used to differentiate generated symbols
    F_hdim: int           # hdim
    F_dtype: str          # data type (key into DTYPE_MAP)
    F_bm0: int            # tile size along q seqlen (block size)
    F_bn0: int            # tile size along k seqlen
    F_rm: int             # 1st entry of the block-warps sequence in the template
    F_rn: int             # 2nd entry of the block-warps sequence in the template
    F_rk: int             # 3rd entry of the block-warps sequence in the template
    F_wm: int             # 1st entry of the warp-tile sequence in the template
    F_wn: int             # 2nd entry of the warp-tile sequence in the template
    F_wk: int             # 3rd entry of the warp-tile sequence in the template
    F_spad: str           # 't'/'f' (key into BOOL_MAP)
    F_dpad: str           # 't'/'f' (key into BOOL_MAP)
    F_mode: str           # key into MODE_MAP
    F_occupancy: int      # third argument of TileFmhaBwdConvertQGradTraits in the template
    F_deterministic: str  # 't'/'f' (key into BOOL_MAP)

    @property
    def template(self) -> str:
        """Return the kernel header followed by the body with all placeholders filled."""
        substitutions = {
            'F_idx': self.F_idx,
            'F_hdim': self.F_hdim,
            'F_dtype': DTYPE_MAP[self.F_dtype],
            'F_bm0': self.F_bm0,
            'F_bn0': self.F_bn0,
            'F_rm': self.F_rm,
            'F_rn': self.F_rn,
            'F_rk': self.F_rk,
            'F_wm': self.F_wm,
            'F_wn': self.F_wn,
            'F_wk': self.F_wk,
            'F_spad': BOOL_MAP[self.F_spad],
            'F_dpad': BOOL_MAP[self.F_dpad],
            'F_mode': MODE_MAP[self.F_mode],
            'F_occupancy': self.F_occupancy,
            'F_deterministic': BOOL_MAP[self.F_deterministic],
        }
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(**substitutions)

    @property
    def name(self) -> str:
        """Return the unique kernel name built from the tunable fields.

        Optional suffixes: ``_p`` plus ``s`` and/or ``d`` when the matching
        pad flag is 't', then ``_deterministic`` when F_deterministic is 't'.
        """
        # Collect one letter per enabled padding flag, in s-then-d order.
        pad_letters = ''
        for letter, flag in (('s', self.F_spad), ('d', self.F_dpad)):
            if flag == 't':
                pad_letters += letter

        segments = [
            f"fmha_bwd_convert_dq_d{self.F_hdim}",
            f"{self.F_dtype}",
            f"b{self.F_bm0}x{self.F_bn0}",
            f"r{self.F_rm}x{self.F_rn}x{self.F_rk}",
            f"w{self.F_wm}x{self.F_wn}x{self.F_wk}",
            f"{self.F_mode}",
            f"o{self.F_occupancy}",
        ]
        if pad_letters:
            segments.append('p' + pad_letters)
        if self.F_deterministic == 't':
            segments.append('deterministic')
        return '_'.join(segments)

    @property
    def filename(self) -> str:
        """Return the name of the generated source file for this kernel."""
        return f"{self.name}.cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
    """Enumerate every convert-dq kernel variant to generate.

    For each dtype in DTYPE_MAP that has a dq/dk/dv tile dictionary, one
    kernel is produced per combination of hdim, mode, spad, dpad and
    deterministic flag. Combinations with mode "group" and spad "f" are
    skipped. The block/warp layout comes from the first entry stored for the
    hdim; F_bm0 is fixed at 64 and F_idx is always 0.

    Returns:
        The list of kernel descriptors, in generation order.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []
    flags = ["t", "f"]

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        # Identity test: None is a singleton, and `==` could be overridden.
        if d is None:
            continue
        for hdim_str, mode, spad, dpad, deterministic in itertools.product(
                d.keys(), MODE_MAP.keys(), flags, flags, flags):
            hdim = int(hdim_str)
            tile = d[hdim_str][0]
            # Group mode is only generated with sequence padding enabled.
            if mode == "group" and spad == "f":
                continue
            k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                          F_bm0=64, F_bn0=tile.F_bn0,
                                          F_rm=tile.F_rm2, F_rn=tile.F_rn2, F_rk=tile.F_rk2,
                                          F_wm=tile.F_wm0, F_wn=tile.F_wn0, F_wk=tile.F_wk0,
                                          F_spad=spad, F_dpad=dpad, F_mode=mode,
                                          F_occupancy=get_occupancy(dtype, hdim),
                                          F_deterministic=deterministic)
            gen.append(k)

    return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write one dq/dk/dv kernel's rendered template to ``autogen_dir / kernel.filename``."""
    (autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write one dot_do_o kernel's rendered template to ``autogen_dir / kernel.filename``."""
    (autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
    """Write one convert-dq kernel's rendered template into ``autogen_dir``."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
def write_bwd_api(api_pool: FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write the API source held by ``api_pool.api`` to ``autogen_dir / FMHA_BWD_API_FILENAME``."""
    (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
...
@@ -595,6 +791,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
...
@@ -595,6 +791,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
kernels
=
get_bwd_dot_do_o_blobs
()
kernels
=
get_bwd_dot_do_o_blobs
()
for
kernel
in
kernels
:
for
kernel
in
kernels
:
write_single_bwd_dot_do_o_kernel
(
kernel
,
output_dir
)
write_single_bwd_dot_do_o_kernel
(
kernel
,
output_dir
)
kernels
=
get_bwd_convert_dq_blobs
()
for
kernel
in
kernels
:
write_single_bwd_convert_dq_kernel
(
kernel
,
output_dir
)
api_pool
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
api_pool
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
for
kernel
in
kernels
:
write_single_bwd_dq_dk_dv_kernel
(
kernel
,
output_dir
)
write_single_bwd_dq_dk_dv_kernel
(
kernel
,
output_dir
)
...
@@ -603,6 +802,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
...
@@ -603,6 +802,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
def
list_blobs
(
file_path
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
list_blobs
(
file_path
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
with
file_path
.
open
(
'a'
)
as
f
:
with
file_path
.
open
(
'a'
)
as
f
:
kernels
=
get_bwd_dot_do_o_blobs
()
kernels
=
get_bwd_dot_do_o_blobs
()
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
kernels
=
get_bwd_convert_dq_blobs
()
for
kernel
in
kernels
:
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment