Unverified Commit 0e3e270f authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[PyTorch] Check if the given recipe is supported in `fp8_autocast` (#2073)



* check if the given recipe is supported in fp8_autocast
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* check only when enabled
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 988af0fd
......@@ -64,14 +64,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
unsupported_reason = ""
if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
recipe_supported, unsupported_reason = check_fp8_support()
elif isinstance(recipe, Float8BlockScaling):
recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
elif isinstance(recipe, MXFP8BlockScaling):
recipe_supported, unsupported_reason = check_mxfp8_support()
assert recipe_supported, unsupported_reason
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if check_mxfp8_support()[0]:
# This is a temporary restriction until MXFP8 is supported for all
# gemm layouts.
if get_device_compute_capability() >= (12, 0):
return Float8BlockScaling()
return MXFP8BlockScaling()
if get_device_compute_capability() >= (12, 0):
# This is a temporary restriction until MXFP8 is supported for all gemm layouts.
return Float8CurrentScaling()
return DelayedScaling()
......@@ -648,6 +660,8 @@ def fp8_autocast(
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
if enabled:
check_recipe_support(fp8_recipe)
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(
enabled=enabled,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment