"vscode:/vscode.git/clone" did not exist on "0967102c6de874902d973f0bcb98d48149a8eb49"
Unverified Commit 40467fc2 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Indexing fix for bug in virtual interleaved pipelining configs (#52)



Virtual interleaved pipelining fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent aadd3e7c
......@@ -115,10 +115,7 @@ def _prepare_backward(fp8: bool,
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
add_amax_to_global_buffer(fp8_meta, forward=False)
......@@ -151,6 +148,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
......@@ -402,6 +400,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment