Unverified Commit 08779fd8 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix FP8 current scaling attention logic (#2234)



* Fix in FP8 attention selection logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5be81251
@@ -597,9 +597,10 @@ class DotProductAttention(TransformerEngineBaseModule):
             ]
             fp8_recipe_dpa = fake_recipes[1]
             fp8_recipes = fake_recipes
-        elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in (
-            "",
-            "Float8CurrentScaling",
+        elif (
+            fp8_recipe.float8_current_scaling()
+            and _dpa_fp8_recipe in ("", "Float8CurrentScaling")
+            and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha)
         ):
             # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
             # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
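The extra `and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha)` clause means the Float8CurrentScaling attention branch is only taken when FP8 attention (DPA or MHA) is actually enabled on the recipe. A minimal sketch of that selection behavior, using a stand-in recipe object; `Recipe`, `select_dpa_recipe`, and the returned strings are illustrative only, not Transformer Engine API:

```python
from dataclasses import dataclass


@dataclass
class Recipe:
    """Stand-in for an FP8 recipe; fields mirror those used in the diff."""

    current_scaling: bool = False
    fp8_dpa: bool = False
    fp8_mha: bool = False

    def float8_current_scaling(self) -> bool:
        return self.current_scaling


def select_dpa_recipe(fp8_recipe: Recipe, _dpa_fp8_recipe: str = "") -> str:
    """Illustrative selection: take the current-scaling attention path only
    when FP8 attention (DPA or MHA) is actually requested on the recipe."""
    if (
        fp8_recipe.float8_current_scaling()
        and _dpa_fp8_recipe in ("", "Float8CurrentScaling")
        and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha)
    ):
        return "current-scaling attention path"
    return "other path"


# Before the fix, a current-scaling recipe with fp8_dpa=fp8_mha=False would
# still have matched the elif; with the extra check it now falls through.
print(select_dpa_recipe(Recipe(current_scaling=True)))                # other path
print(select_dpa_recipe(Recipe(current_scaling=True, fp8_dpa=True)))  # current-scaling attention path
```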