Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
14b4d6bb
"doc/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "2e754b87af8b2bdb3dfb4e6838426ad7654ba591"
Commit
14b4d6bb
authored
Feb 12, 2025
by
Jim
Browse files
fix fp16 case
parent
9f24a7ed
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
7 deletions
+6
-7
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+6
-7
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
14b4d6bb
...
@@ -800,11 +800,9 @@ class FmhaBwdApiPool:
...
@@ -800,11 +800,9 @@ class FmhaBwdApiPool:
v3_code
=
str
()
v3_code
=
str
()
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
for
i
,
dtype
in
enumerate
(
self
.
dq_dk_dv_v3_pool
.
keys
()):
per_bf16
=
str
()
per_bf16
_cvt
=
str
()
for
j
,
bf16_cvt
in
enumerate
([
0
,
1
,
2
]):
for
j
,
bf16_cvt
in
enumerate
([
0
,
1
,
2
]):
per_mask
=
str
()
per_mask
=
str
()
if
(
dtype
==
"fp16"
)
and
(
bf16_cvt
in
[
1
,
2
]):
continue
for
k
,
is_causal
in
enumerate
([
"t"
,
"f"
]):
for
k
,
is_causal
in
enumerate
([
"t"
,
"f"
]):
per_atomic
=
str
()
per_atomic
=
str
()
for
l
,
is_atomic
in
enumerate
([
"t"
,
"f"
]):
for
l
,
is_atomic
in
enumerate
([
"t"
,
"f"
]):
...
@@ -812,10 +810,11 @@ class FmhaBwdApiPool:
...
@@ -812,10 +810,11 @@ class FmhaBwdApiPool:
for
m
,
hdim
in
enumerate
(
BWD_V3_HDIM_CASE_CHECK_MAP
.
values
()):
for
m
,
hdim
in
enumerate
(
BWD_V3_HDIM_CASE_CHECK_MAP
.
values
()):
if_m
=
'if'
if
m
==
0
else
'else if'
if_m
=
'if'
if
m
==
0
else
'else if'
inners
=
str
()
inners
=
str
()
bf16_cvt_tmp
=
0
if
dtype
==
"fp16"
else
bf16_cvt
if
is_atomic
==
"t"
:
if
is_atomic
==
"t"
:
inners
=
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
inners
=
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
else
:
else
:
inners
=
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
inners
=
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt
_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
per_hdim
=
per_hdim
+
FMHA_BWD_V3_PER_HDIM_CASE
.
format
(
F_if
=
if_m
,
F_hdim_expression
=
BWD_V3_HDIM_CASE_MAP
[
m
],
inner_dispatch
=
inners
)
per_hdim
=
per_hdim
+
FMHA_BWD_V3_PER_HDIM_CASE
.
format
(
F_if
=
if_m
,
F_hdim_expression
=
BWD_V3_HDIM_CASE_MAP
[
m
],
inner_dispatch
=
inners
)
if_l
=
'if'
if
l
==
0
else
'else if'
if_l
=
'if'
if
l
==
0
else
'else if'
...
@@ -823,9 +822,9 @@ class FmhaBwdApiPool:
...
@@ -823,9 +822,9 @@ class FmhaBwdApiPool:
if_k
=
'if'
if
k
==
0
else
'else if'
if_k
=
'if'
if
k
==
0
else
'else if'
per_mask
=
per_mask
+
FMHA_BWD_V3_PER_MASK_CASE
.
format
(
F_if
=
if_k
,
F_mask_expression
=
BWD_V3_MASK_MAP
[
is_causal
],
per_atomic_dispatch
=
per_atomic
)
per_mask
=
per_mask
+
FMHA_BWD_V3_PER_MASK_CASE
.
format
(
F_if
=
if_k
,
F_mask_expression
=
BWD_V3_MASK_MAP
[
is_causal
],
per_atomic_dispatch
=
per_atomic
)
if_j
=
'if'
if
j
==
0
else
'else if'
if_j
=
'if'
if
j
==
0
else
'else if'
per_bf16
=
per_bf16
+
FMHA_BWD_V3_PER_BF16_CVT_CASE
.
format
(
F_if
=
if_j
,
F_bf16_cvt
=
bf16_cvt
,
per_mask_dispatch
=
per_mask
)
per_bf16
_cvt
=
per_bf16
_cvt
+
FMHA_BWD_V3_PER_BF16_CVT_CASE
.
format
(
F_if
=
if_j
,
F_bf16_cvt
=
bf16_cvt
,
per_mask_dispatch
=
per_mask
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
v3_code
=
v3_code
+
FMHA_BWD_V3_PER_DTYPE_CASE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
per_bf16_cvt_dispatch
=
per_bf16
)
v3_code
=
v3_code
+
FMHA_BWD_V3_PER_DTYPE_CASE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
per_bf16_cvt_dispatch
=
per_bf16
_cvt
)
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
,
F_template
=
gen_template
,
F_v3_dispatch
=
v3_code
)
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
,
F_template
=
gen_template
,
F_v3_dispatch
=
v3_code
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment