Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f8b14618
Commit
f8b14618
authored
Jul 30, 2024
by
danyao12
Browse files
fix hd64 scratches and boost performance
parent
5d2a5a11
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
4 deletions
+2
-4
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+2
-4
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
f8b14618
...
@@ -278,8 +278,6 @@ class FmhaBwdApiPool:
...
@@ -278,8 +278,6 @@ class FmhaBwdApiPool:
for
spad1
in
[
"t"
,
"f"
]:
for
spad1
in
[
"t"
,
"f"
]:
if
(
spad1
==
"f"
and
(
trait
.
spad
==
"t"
or
trait
.
mode
==
"group"
)):
if
(
spad1
==
"f"
and
(
trait
.
spad
==
"t"
or
trait
.
mode
==
"group"
)):
continue
continue
if
(
spad1
==
"t"
and
trait
.
spad
==
"f"
and
hdim_int
==
64
):
continue
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
inners
=
inners
+
FMHA_BWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_pipeline_enum
=
BWD_DQDKDV_PIPELINE_ENUM_MAP
[
trait
.
pipeline
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_dbias
=
BOOL_MAP
[
trait
.
dbias
],
F_dropout_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
],
...
@@ -453,7 +451,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
...
@@ -453,7 +451,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
return
{
return
{
'32'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
32
,
32
,
32
,
32
,
64
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'32'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
32
,
32
,
32
,
32
,
64
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr"
],
'64'
:
[
FmhaBwdDQDKDVTileSize
(
64
,
128
,
64
,
64
,
64
,
64
,
64
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
32
,
32
,
16
,
32
,
32
,
16
,
1
),
'64'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
64
,
32
,
64
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr"
],
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr"
],
...
@@ -481,7 +479,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -481,7 +479,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
continue
if
((
bias
==
"no"
or
bias
==
"alibi"
)
and
dbias
==
"t"
):
if
((
bias
==
"no"
or
bias
==
"alibi"
)
and
dbias
==
"t"
):
continue
continue
if
(((
hdim
==
64
)
and
(
"wg16"
in
dropout
))
or
((
hdim
!=
64
)
and
(
"wg32"
in
dropout
)
))
:
if
(
"wg32"
in
dropout
):
continue
continue
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment