Unverified Commit 0e3e270f authored by Xin Yao, committed by GitHub

[PyTorch] Check if the given recipe is supported in `fp8_autocast` (#2073)



* check if the given recipe is supported in fp8_autocast
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* check only when enabled
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 988af0fd
@@ -64,14 +64,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."


def check_recipe_support(recipe: Recipe) -> None:
    """Check if the given recipe is supported."""
    recipe_supported = True
    unsupported_reason = ""
    if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
        recipe_supported, unsupported_reason = check_fp8_support()
    elif isinstance(recipe, Float8BlockScaling):
        recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
    elif isinstance(recipe, MXFP8BlockScaling):
        recipe_supported, unsupported_reason = check_mxfp8_support()
    assert recipe_supported, unsupported_reason


def get_default_fp8_recipe() -> Recipe:
    """FP8 recipe with default args."""
    if check_mxfp8_support()[0]:
        # This is a temporary restriction until MXFP8 is supported for all
        # gemm layouts.
        if get_device_compute_capability() >= (12, 0):
            return Float8BlockScaling()
        return MXFP8BlockScaling()
    if get_device_compute_capability() >= (12, 0):
        # This is a temporary restriction until MXFP8 is supported for all gemm layouts.
        return Float8CurrentScaling()
    return DelayedScaling()
@@ -648,6 +660,8 @@ def fp8_autocast(
        distributed group over which amaxes for the fp8 tensors
        are reduced at the end of each training step.
    """
    if enabled:
        check_recipe_support(fp8_recipe)
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    FP8GlobalStateManager.fp8_autocast_enter(
        enabled=enabled,
......
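For context, here is a minimal usage sketch (not part of this diff) of how the new check surfaces to callers. It assumes the standard `transformer_engine.pytorch` entry points (`te.fp8_autocast`, `te.Linear`) and the recipe classes referenced in the diff, and that the example runs on a GPU that supports FP8 but not MXFP8; exact import paths and supported recipes depend on the installed Transformer Engine version and hardware.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

layer = te.Linear(128, 128).cuda()
inp = torch.randn(16, 128, device="cuda")

# A recipe supported on this GPU works as before.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = layer(inp)

# With this change, an unsupported recipe fails fast at autocast entry with the
# reason returned by the corresponding check_*_support() helper, instead of
# failing later inside a GEMM.
try:
    with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
        out = layer(inp)
except AssertionError as exc:
    print(f"Recipe not supported: {exc}")

# When enabled=False, the recipe is not validated ("check only when enabled").
with te.fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
    out = layer(inp)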