"tests/vscode:/vscode.git/clone" did not exist on "dab931a7aea1cc72fb480e20a083778aa4e44a4b"
Unverified Commit 7530b768 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Handle nested fp8 autocasts (#187)



Fixes in nested autocast
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 36e0ac56
......@@ -276,7 +276,12 @@ def fp8_autocast(
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
global _amax_reduce_handle_fwd
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
fp8_state = (
_FP8_ENABLED,
_FP8_CALIBRATION,
_FP8_RECIPE,
_FP8_DISTRIBUTED_GROUP,
_IS_FIRST_FP8_MODULE)
try:
_FP8_ENABLED = enabled
_FP8_CALIBRATION = calibrating
......@@ -293,8 +298,12 @@ def fp8_autocast(
assert fp8_available, reason_for_no_fp8
yield
finally:
_FP8_ENABLED,_FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
_IS_FIRST_FP8_MODULE = False
(_FP8_ENABLED,
_FP8_CALIBRATION,
_FP8_RECIPE,
_FP8_DISTRIBUTED_GROUP,
_IS_FIRST_FP8_MODULE) = fp8_state
_FP8_AUTOCAST_DEPTH -= 1
if _FP8_AUTOCAST_DEPTH == 0:
......
......@@ -543,14 +543,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
if is_fp8_enabled() or is_fp8_calibration():
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment