Unverified Commit 40467fc2 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Indexing fix for bug in virtual interleaved pipelining configs (#52)



Virtual interleaved pipelining fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent aadd3e7c
...@@ -115,10 +115,7 @@ def _prepare_backward(fp8: bool, ...@@ -115,10 +115,7 @@ def _prepare_backward(fp8: bool,
set_amax_buffer_key_deletion(fp8_meta, forward=False) set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key. # Get new backward key.
if "autocast_id_bwd" not in fp8_meta: fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False) add_amax_to_global_buffer(fp8_meta, forward=False)
...@@ -151,6 +148,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -151,6 +148,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1 self.tp_size = 1
self.sequence_parallel = False self.sequence_parallel = False
self.fp8_weight_shapes = [] self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
def set_meta_tensor(self, fwd: bool) -> None: def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
...@@ -402,6 +400,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -402,6 +400,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else: else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id() self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
add_amax_to_global_buffer(self.fp8_meta, forward=True) add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True self.fp8_meta["update_amax_and_scale_fwd"] = True
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment