Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
86923c19
Commit
86923c19
authored
Feb 11, 2025
by
Jim
Browse files
update
parent
8b745f2c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6075 additions
and
54 deletions
+6075
-54
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+33
-0
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+138
-53
example/ck_tile/01_fmha/fmha_fwd_api.cpp
example/ck_tile/01_fmha/fmha_fwd_api.cpp
+5903
-0
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+1
-1
No files found.
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
86923c19
...
@@ -127,8 +127,41 @@ BOOL_MAP = {
...
@@ -127,8 +127,41 @@ BOOL_MAP = {
"f"
:
"false"
"f"
:
"false"
}
}
BWD_V3_HDIM_MAP
=
{
"64"
:
"64"
,
"128"
:
"128"
}
BF16_CVT_MAP
=
{
BF16_CVT_MAP
=
{
0
:
"rtne"
,
0
:
"rtne"
,
1
:
"rtna"
,
1
:
"rtna"
,
2
:
"rtz"
,
2
:
"rtz"
,
}
}
BWD_V3_MASK_MAP
=
{
"t"
:
"((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0)))"
,
"f"
:
"(t.mask_type == mask_enum::no_mask)"
}
BWD_V3_ATOMIC32_MAP
=
{
"t"
:
"((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/))"
,
"f"
:
"(t.is_v3_atomic_fp32 == false)"
}
BWD_V3_HDIM_CASE_MAP
=
{
0
:
"(a.hdim_q == 128)"
,
1
:
"(a.hdim_q == 64)"
,
2
:
"((a.hdim_q > 64) && (a.hdim_q < 128))"
}
BWD_V3_HDIM_CASE_CHECK_MAP
=
{
0
:
128
,
1
:
64
,
2
:
128
}
BWD_V3_PADDING_CHECK_MAP
=
{
0
:
"false"
,
1
:
"false"
,
2
:
"true"
}
\ No newline at end of file
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
86923c19
...
@@ -162,9 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
...
@@ -162,9 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME
=
"fmha_bwd_api.cpp"
FMHA_BWD_API_FILENAME
=
"fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE
=
"""
FMHA_BWD_V3_TEMPLATE
=
"""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
"""
"""
...
@@ -589,7 +588,7 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -589,7 +588,7 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
float r = -1;
if
(
(t.uses_bwd_v3 == true){{
if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) &&
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) &&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
...
@@ -623,37 +622,39 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
...
@@ -623,37 +622,39 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
}}
}}
"""
"""
FMHA_V3_DISPATCH
=
"""
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
=
""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
if (t.mask_type == mask_enum::no_mask){{
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;"""
}}
else if (t.is_v3_atomic_fp32 == false){{
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
=
""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;"""
}}
FMHA_BWD_V3_PER_DTYPE_CASE
=
""" {F_if} (t.data_type.compare(
\"
{F_dtype}
\"
) == 0) {{
{per_bf16_cvt_dispatch}
}}
}}
else if ((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
"""
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
FMHA_BWD_V3_PER_BF16_CVT_CASE
=
""" {F_if} (t.how_v3_bf16_cvt == {F_bf16_cvt}) {{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
{per_mask_dispatch}
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if (t.is_v3_atomic_fp32 == false){{
"""
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
FMHA_BWD_V3_PER_MASK_CASE
=
""" {F_if} {F_mask_expression}{{
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
{per_atomic_dispatch}
return r;
}}
}}
"""
FMHA_BWD_V3_PER_ATOMIC_CASE
=
""" {F_if} {F_atomic_expression}{{
{per_hdim_dispatch}
}}
}}
"""
FMHA_BWD_V3_PER_HDIM_CASE
=
""" {F_if} {F_hdim_expression}{{
{inner_dispatch}
}}
}}
"""
"""
...
@@ -714,11 +715,18 @@ class FmhaBwdV3DQDKDVApiTrait:
...
@@ -714,11 +715,18 @@ class FmhaBwdV3DQDKDVApiTrait:
is_causal
:
str
is_causal
:
str
is_atomic
:
str
is_atomic
:
str
bf16_cvt
:
int
bf16_cvt
:
int
hd_pad
:
str
is_hdpad
:
str
def
remap_hdim
(
self
):
hdim_int
=
int
(
self
.
hdim
)
if
hdim_int
>
64
:
self
.
hdim
=
128
hdim_int
=
(
hdim_int
+
64
-
1
)
/
64
*
64
class
FmhaBwdApiPool
:
class
FmhaBwdApiPool
:
def
__init__
(
self
,
mask_impl
):
def
__init__
(
self
,
mask_impl
):
self
.
dq_dk_dv_pool
=
dict
()
self
.
dq_dk_dv_pool
=
dict
()
self
.
dq_dk_dv_v3_pool
=
dict
()
self
.
mask_impl
=
mask_impl
self
.
mask_impl
=
mask_impl
def
register_dq_dk_dv_traits
(
self
,
trait
:
FmhaBwdDQDKDVApiTrait
)
->
None
:
def
register_dq_dk_dv_traits
(
self
,
trait
:
FmhaBwdDQDKDVApiTrait
)
->
None
:
...
@@ -730,6 +738,15 @@ class FmhaBwdApiPool:
...
@@ -730,6 +738,15 @@ class FmhaBwdApiPool:
self
.
dq_dk_dv_pool
[
trait
.
dtype
][
trait
.
hdim
].
append
(
copy
.
copy
(
trait
))
self
.
dq_dk_dv_pool
[
trait
.
dtype
][
trait
.
hdim
].
append
(
copy
.
copy
(
trait
))
def
register_dq_dk_dv_v3_traits
(
self
,
trait
:
FmhaBwdV3DQDKDVApiTrait
)
->
None
:
# TODO: do we need to check duplication?
if
trait
.
dtype
not
in
self
.
dq_dk_dv_v3_pool
.
keys
():
self
.
dq_dk_dv_v3_pool
[
trait
.
dtype
]
=
dict
()
if
trait
.
hdim
not
in
self
.
dq_dk_dv_v3_pool
[
trait
.
dtype
].
keys
():
self
.
dq_dk_dv_v3_pool
[
trait
.
dtype
][
trait
.
hdim
]
=
list
()
self
.
dq_dk_dv_v3_pool
[
trait
.
dtype
][
trait
.
hdim
].
append
(
copy
.
copy
(
trait
))
@
property
@
property
def
api
(
self
)
->
str
:
def
api
(
self
)
->
str
:
per_dtypes
=
str
()
per_dtypes
=
str
()
...
@@ -758,28 +775,51 @@ class FmhaBwdApiPool:
...
@@ -758,28 +775,51 @@ class FmhaBwdApiPool:
# empty string we add some ignore to suppress warning in api
# empty string we add some ignore to suppress warning in api
per_dtypes
+=
' (void)t ; (void)s ; (void)a;'
per_dtypes
+=
' (void)t ; (void)s ; (void)a;'
v3_code
=
str
()
gen_template
=
str
()
gen_template
=
str
()
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_pool
.
keys
()):
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
for
j
,
hdim
in
enumerate
(
self
.
dq_dk_dv_pool
[
dtype
].
keys
()):
for
j
,
hdim
in
enumerate
(
BWD_V3_HDIM_MAP
.
keys
()):
traits
=
self
.
dq_dk_dv_pool
[
dtype
][
hdim
]
traits
=
self
.
dq_dk_dv_v3_pool
[
dtype
][
hdim
]
hdim_int
=
int
(
hdim
)
hdim
=
int
(
hdim
)
hdim_int
=
(
hdim_int
+
64
-
1
)
/
64
*
64
Ts_qo
=
32
if
hdim
==
64
else
16
Ts_qo
=
32
if
hdim
==
64
else
16
for
k
,
trait
in
enumerate
(
traits
):
for
k
,
trait
in
enumerate
(
traits
):
padding
=
"t"
if
hdim_int
%
64
==
0
else
"f"
hdim_name
=
"_hd64"
if
hdim
==
64
else
""
padding_suffix
=
"_hdp"
if
padding
==
"t"
else
""
dtype_name
=
"_{}"
.
format
(
dtype
)
v3_code
=
v3_code
+
FMHA_V3_DISPATCH
.
format
(
F_hdim
=
hdim_int
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_padding
=
BOOL_MAP
[
padding
],
causal_name
=
"_causal"
if
trait
.
is_causal
==
"t"
else
""
F_is_atomic32
=
BOOL_MAP
[
trait
.
is_atomic
],
F_how_v3_bf16_cvt
=
trait
.
bf16_cvt
,
F_padding_suffix
=
padding_suffix
)
atomic_name
=
"_a32"
if
trait
.
is_atomic
==
"t"
else
"_a16"
bf16_cvt_name
=
"_{}"
.
format
(
BF16_CVT_MAP
[
trait
.
bf16_cvt
])
hdim_name
=
"hd64"
if
hdim_int
==
64
else
""
gen_template
=
gen_template
+
FMHA_BWD_V3_TEMPLATE
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_atomic
=
BOOL_MAP
[
trait
.
is_atomic
],
dtype_name
=
dtype
F_is_causal
=
BOOL_MAP
[
trait
.
is_causal
],
F_bf16_cvt
=
trait
.
bf16_cvt
,
F_hdpad
=
BOOL_MAP
[
trait
.
is_hdpad
],
F_Ts_qo
=
Ts_qo
,
F_hdim_name
=
hdim_name
,
causal_name
=
"causal"
if
trait
.
is_causal
==
"t"
else
""
atomic_name
=
"a32"
if
trait
.
is_atomic
==
"t"
else
"a16"
bf16_cvt_name
=
BF16_CVT_MAP
[
trait
.
bf16_cvt
]
gen_template
=
gen_template
+
FMHA_BWD_V3_TEMPLATE
.
format
(
F_hdim
=
hdim_int
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_atomic
=
BOOL_MAP
[
trait
.
is_atomic
],
F_is_causal
=
BOOL_MAP
[
trait
.
is_causal
],
F_bf16_cvt
=
trait
.
bf16_cvt
,
F_hdpad
=
BOOL_MAP
[
padding
],
F_Ts_qo
=
Ts_qo
,
F_hdim_name
=
hdim_name
,
F_dtype_name
=
dtype_name
,
F_causal_name
=
causal_name
,
F_atomic_name
=
atomic_name
,
F_bf16_cvt_name
=
bf16_cvt_name
)
F_dtype_name
=
dtype_name
,
F_causal_name
=
causal_name
,
F_atomic_name
=
atomic_name
,
F_bf16_cvt_name
=
bf16_cvt_name
)
v3_code
=
str
()
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
per_bf16
=
str
()
for
j
,
bf16_cvt
in
enumerate
([
0
,
1
,
2
]):
per_mask
=
str
()
if
(
dtype
==
"fp16"
)
and
(
bf16_cvt
in
[
1
,
2
]):
continue
for
k
,
is_causal
in
enumerate
([
"t"
,
"f"
]):
per_atomic
=
str
()
for
l
,
is_atomic
in
enumerate
([
"t"
,
"f"
]):
per_hdim
=
str
()
for
m
,
hdim
in
enumerate
(
BWD_V3_HDIM_CASE_CHECK_MAP
.
values
()):
if_m
=
'if'
if
m
==
0
else
'else if'
inners
=
str
()
if
is_atomic
==
"t"
:
inners
=
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
else
:
inners
=
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
per_hdim
=
per_hdim
+
FMHA_BWD_V3_PER_HDIM_CASE
.
format
(
F_if
=
if_m
,
F_hdim_expression
=
BWD_V3_HDIM_CASE_MAP
[
m
],
inner_dispatch
=
inners
)
if_l
=
'if'
if
l
==
0
else
'else if'
per_atomic
=
per_atomic
+
FMHA_BWD_V3_PER_ATOMIC_CASE
.
format
(
F_if
=
if_l
,
F_atomic_expression
=
BWD_V3_ATOMIC32_MAP
[
is_atomic
],
per_hdim_dispatch
=
per_hdim
)
if_k
=
'if'
if
k
==
0
else
'else if'
per_mask
=
per_mask
+
FMHA_BWD_V3_PER_MASK_CASE
.
format
(
F_if
=
if_k
,
F_mask_expression
=
BWD_V3_MASK_MAP
[
is_causal
],
per_atomic_dispatch
=
per_atomic
)
if_j
=
'if'
if
j
==
0
else
'else if'
per_bf16
=
per_bf16
+
FMHA_BWD_V3_PER_BF16_CVT_CASE
.
format
(
F_if
=
if_j
,
F_bf16_cvt
=
bf16_cvt
,
per_mask_dispatch
=
per_mask
)
if_i
=
'if'
if
i
==
0
else
'else if'
v3_code
=
v3_code
+
FMHA_BWD_V3_PER_DTYPE_CASE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
per_bf16_cvt_dispatch
=
per_bf16
)
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
,
F_template
=
gen_template
,
F_v3_dispatch
=
v3_code
)
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
,
F_template
=
gen_template
,
F_v3_dispatch
=
v3_code
)
# GEMM0: Q@K=S^T
# GEMM0: Q@K=S^T
...
@@ -932,6 +972,40 @@ class FmhaBwdDQDKDVKernel:
...
@@ -932,6 +972,40 @@ class FmhaBwdDQDKDVKernel:
deterministic
=
self
.
F_deterministic
deterministic
=
self
.
F_deterministic
)
)
@
dataclass
class
FmhaBwdV3DQDKDVKernel
:
F_hdim
:
int
# hdim
F_dtype
:
str
# data type
F_is_causal
:
str
F_is_atomic
:
str
F_bf16_cvt
:
int
F_is_hdpad
:
str
# @property
# def gen_bwd_v3_template(self) -> str:
# hdim_template = 64 if self.F_hdim == 64 else 128
# Ts_qo = 32 if hdim_template == 64 else 16
# padding = "t" if self.F_hdim % 64 == 0 else "f"
# padding_suffix = "_hdp" if padding == "t" else ""
# hdim_name = "_hd64" if hdim_template == 64 else ""
# dtype_name = "_{}".format(self.F_dtype)
# causal_name = "_causal" if self.F_is_causal == "t" else ""
# atomic_name = "_a32" if self.F_is_atomic == "t" else "_a16"
# bf16_cvt_name = "_{}".format(BF16_CVT_MAP[self.F_bf16_cvt])
# return FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim_template, F_dtype=BWD_DTYPE_MAP[self.F_dtype], F_is_atomic=BOOL_MAP[self.F_is_atomic],
# F_is_causal=BOOL_MAP[self.F_is_causal], F_bf16_cvt=self.F_bf16_cvt, F_hdpad=BOOL_MAP[padding], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
# F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name)
def
v3_api_trait
(
self
)
->
FmhaBwdV3DQDKDVApiTrait
:
return
FmhaBwdV3DQDKDVApiTrait
(
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
is_causal
=
self
.
F_is_causal
,
is_atomic
=
self
.
F_is_atomic
,
bf16_cvt
=
self
.
F_bf16_cvt
,
is_hdpad
=
self
.
F_is_hdpad
)
# TODO: design a more practical way to do it
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
# this is current supported tile size & pipeline.
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
...
@@ -995,6 +1069,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -995,6 +1069,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
for
hdim_str
,
is_causal
,
is_atomic
,
bf16_cvt
,
is_hdpad
in
itertools
.
product
(
d
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
],
[
0
,
1
,
2
],
[
"t"
,
"f"
]):
hdim
=
int
(
hdim_str
)
k
=
FmhaBwdV3DQDKDVKernel
(
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_is_causal
=
is_causal
,
F_is_atomic
=
is_atomic
,
F_bf16_cvt
=
bf16_cvt
,
F_is_hdpad
=
is_hdpad
)
if
receipt
==
3
:
cond
=
(
dtype
==
'fp16'
)
and
(
bf16_cvt
in
[
1
,
2
])
if
cond
:
# print(dtype, bf16_cvt)
continue
api_pool
.
register_dq_dk_dv_v3_traits
(
k
.
v3_api_trait
())
# gen.append(k)
return
(
api_pool
,
gen
)
return
(
api_pool
,
gen
)
FMHA_BWD_DOT_DO_O_KERNEL_BODY
=
"""
FMHA_BWD_DOT_DO_O_KERNEL_BODY
=
"""
...
...
example/ck_tile/01_fmha/fmha_fwd_api.cpp
0 → 100644
View file @
86923c19
This source diff could not be displayed because it is too large. You can
view the blob
instead.
example/ck_tile/01_fmha/generate.py
View file @
86923c19
...
@@ -99,7 +99,7 @@ if __name__ == "__main__":
...
@@ -99,7 +99,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"-r"
,
"-r"
,
"--receipt"
,
"--receipt"
,
default
=
0
,
default
=
3
,
required
=
False
,
required
=
False
,
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment