Unverified commit 7530b768 authored by Kirthi Shankar Sivamani, committed by GitHub

Handle nested fp8 autocasts (#187)



Fixes state save/restore in nested fp8_autocast contexts.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 36e0ac56
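
The fix targets the nesting pattern sketched below. This is an illustrative example, not code from the commit: the layer sizes, recipe settings, and variable names are assumptions, and only fp8_autocast, Linear, DelayedScaling, and Format are actual Transformer Engine names. The point is that exiting the inner region must leave the outer region's FP8 state fully intact.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative recipe; any valid recipe works for this sketch.
fp8_recipe = DelayedScaling(fp8_format=Format.E4M3, amax_history_len=16)

linear = te.Linear(256, 256).cuda()
inp = torch.randn(16, 256, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = linear(inp)                     # FP8 execution under the outer context
    with te.fp8_autocast(enabled=False):  # nested: temporarily disable FP8
        out = linear(out)                 # higher-precision execution
    out = linear(out)                     # outer FP8 state must be restored here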
@@ -276,7 +276,12 @@ def fp8_autocast(
     global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
     global _global_fp8_buffer, _buffer_delete_key_fwd
     global _amax_reduce_handle_fwd
-    fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
+    fp8_state = (
+        _FP8_ENABLED,
+        _FP8_CALIBRATION,
+        _FP8_RECIPE,
+        _FP8_DISTRIBUTED_GROUP,
+        _IS_FIRST_FP8_MODULE)
     try:
         _FP8_ENABLED = enabled
         _FP8_CALIBRATION = calibrating
@@ -293,8 +298,12 @@ def fp8_autocast(
         assert fp8_available, reason_for_no_fp8
         yield
     finally:
-        _FP8_ENABLED,_FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
-        _IS_FIRST_FP8_MODULE = False
+        (_FP8_ENABLED,
+         _FP8_CALIBRATION,
+         _FP8_RECIPE,
+         _FP8_DISTRIBUTED_GROUP,
+         _IS_FIRST_FP8_MODULE) = fp8_state
         _FP8_AUTOCAST_DEPTH -= 1
         if _FP8_AUTOCAST_DEPTH == 0:
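
The hunk above adds _IS_FIRST_FP8_MODULE to the saved fp8_state tuple and restores it in the finally block, replacing the old behavior of unconditionally forcing it to False on exit, which clobbered the outer context when autocasts were nested. A minimal standalone sketch of this save/restore pattern, with simplified names rather than the real Transformer Engine globals:

from contextlib import contextmanager

_ENABLED = False
_IS_FIRST_MODULE = False
_DEPTH = 0

@contextmanager
def autocast(enabled: bool = True):
    global _ENABLED, _IS_FIRST_MODULE, _DEPTH
    # Capture every piece of global state this context manager mutates.
    saved = (_ENABLED, _IS_FIRST_MODULE)
    try:
        _ENABLED = enabled
        if _DEPTH == 0:
            _IS_FIRST_MODULE = True
        _DEPTH += 1
        yield
    finally:
        # Restore the saved state instead of resetting flags to a constant,
        # so an inner context cannot clobber the outer context's view.
        (_ENABLED, _IS_FIRST_MODULE) = saved
        _DEPTH -= 1

with autocast():
    with autocast(enabled=False):
        assert not _ENABLED
    assert _ENABLED and _IS_FIRST_MODULE  # outer state survives the nested exit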
@@ -543,14 +543,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     # assume FP8 execution.
     def fp8_init(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        if is_fp8_enabled() or is_fp8_calibration():
+        self.fp8 = is_fp8_enabled()
+        self.fp8_calibration = is_fp8_calibration()
+
+        if self.fp8 or self.fp8_calibration:
             # FP8 init has already been run and recipe is the same, don't do anything.
             if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
                 return

             # Set FP8, recipe, and other FP8 metadata
-            self.fp8 = is_fp8_enabled()
-            self.fp8_calibration = is_fp8_calibration()
             self.fp8_meta["recipe"] = get_fp8_recipe()
             self.fp8_meta["num_gemms"] = num_gemms
             self.fp8_meta["fp8_group"] = get_fp8_group()
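
This hunk hoists the self.fp8 and self.fp8_calibration assignments out of the conditional so both flags are refreshed on every call, including calls made outside any FP8 region. A minimal standalone sketch of the behavioral difference, using a hypothetical Module class rather than the actual TransformerEngineBaseModule:

class Module:
    """Simplified stand-in for a module that caches the FP8 context flags."""

    def __init__(self) -> None:
        self.fp8 = False

    def fp8_init(self, fp8_enabled: bool) -> None:
        # Refresh unconditionally: a module that once ran in FP8 no longer
        # reports self.fp8 == True after it is called outside the autocast
        # region (or inside a nested, disabled one).
        self.fp8 = fp8_enabled
        if self.fp8:
            pass  # recipe/metadata setup happens only when FP8 is active

m = Module()
m.fp8_init(fp8_enabled=True)   # inside an enabled region
assert m.fp8
m.fp8_init(fp8_enabled=False)  # outside, or inside a disabled nested region
assert not m.fp8               # with the old code, this stayed True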