Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7d45045c
Commit
7d45045c
authored
Feb 12, 2025
by
Jim
Browse files
update: add restricts
parent
7857f621
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
8 deletions
+15
-8
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+15
-8
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
7d45045c
...
@@ -162,8 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
...
@@ -162,8 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME
=
"fmha_bwd_api.cpp"
FMHA_BWD_API_FILENAME
=
"fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE
=
"""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}"; }};
FMHA_BWD_V3_TEMPLATE
=
"""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}
{F_hdpad_name}
"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}
{F_hdpad_name}
; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
"""
"""
...
@@ -782,14 +782,18 @@ class FmhaBwdApiPool:
...
@@ -782,14 +782,18 @@ class FmhaBwdApiPool:
hdim
=
int
(
hdim
)
hdim
=
int
(
hdim
)
Ts_qo
=
32
if
hdim
==
64
else
16
Ts_qo
=
32
if
hdim
==
64
else
16
for
k
,
trait
in
enumerate
(
traits
):
for
k
,
trait
in
enumerate
(
traits
):
if
hdim
==
64
and
trait
.
is_hdpad
==
"t"
:
continue
hdim_name
=
"_hd64"
if
hdim
==
64
else
""
hdim_name
=
"_hd64"
if
hdim
==
64
else
""
dtype_name
=
"_{}"
.
format
(
dtype
)
dtype_name
=
"_{}"
.
format
(
dtype
)
causal_name
=
"_causal"
if
trait
.
is_causal
==
"t"
else
""
causal_name
=
"_causal"
if
trait
.
is_causal
==
"t"
else
""
atomic_name
=
"_a32"
if
trait
.
is_atomic
==
"t"
else
"_a16"
atomic_name
=
"_a32"
if
trait
.
is_atomic
==
"t"
else
"_a16"
bf16_cvt_name
=
"_{}"
.
format
(
BF16_CVT_MAP
[
trait
.
bf16_cvt
])
bf16_cvt_name
=
"_{}"
.
format
(
BF16_CVT_MAP
[
trait
.
bf16_cvt
])
bf16_cvt_name
=
bf16_cvt_name
if
dtype
==
"bf16"
else
""
hdpad_name
=
"_pddv"
if
trait
.
is_hdpad
==
"t"
else
""
gen_template
=
gen_template
+
FMHA_BWD_V3_TEMPLATE
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_atomic
=
BOOL_MAP
[
trait
.
is_atomic
],
gen_template
=
gen_template
+
FMHA_BWD_V3_TEMPLATE
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_atomic
=
BOOL_MAP
[
trait
.
is_atomic
],
F_is_causal
=
BOOL_MAP
[
trait
.
is_causal
],
F_bf16_cvt
=
trait
.
bf16_cvt
,
F_hdpad
=
BOOL_MAP
[
trait
.
is_hdpad
],
F_Ts_qo
=
Ts_qo
,
F_hdim_name
=
hdim_name
,
F_is_causal
=
BOOL_MAP
[
trait
.
is_causal
],
F_bf16_cvt
=
trait
.
bf16_cvt
,
F_hdpad
=
BOOL_MAP
[
trait
.
is_hdpad
],
F_Ts_qo
=
Ts_qo
,
F_hdim_name
=
hdim_name
,
F_dtype_name
=
dtype_name
,
F_causal_name
=
causal_name
,
F_atomic_name
=
atomic_name
,
F_bf16_cvt_name
=
bf16_cvt_name
)
F_dtype_name
=
dtype_name
,
F_causal_name
=
causal_name
,
F_atomic_name
=
atomic_name
,
F_bf16_cvt_name
=
bf16_cvt_name
,
F_hdpad_name
=
hdpad_name
)
v3_code
=
str
()
v3_code
=
str
()
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
...
@@ -1011,14 +1015,14 @@ class FmhaBwdV3DQDKDVKernel:
...
@@ -1011,14 +1015,14 @@ class FmhaBwdV3DQDKDVKernel:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
32
,
32
,
32
,
32
,
64
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
#
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
#
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
64
,
32
,
64
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'64'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
64
,
32
,
64
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
'256'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
64
,
256
,
16
,
256
,
16
,
32
,
256
,
256
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
#
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
]
#
"kr_ktr_vr_iglp", "kr_ktr_vr"]
}
}
else
:
else
:
return
None
return
None
...
@@ -1061,8 +1065,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -1061,8 +1065,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
continue
if
receipt
==
3
:
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
]
cond
&=
dropout
in
[
'no'
]
cond
&=
dpad
==
dvpad
cond
&=
dpad
==
dvpad
cond
&=
spad
==
skpad
cond
&=
spad
==
"f"
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment