Unverified Commit d10dfb57 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Full activation recompute checkpointing bug fix (#31)



fix checkpoint loading bug for FAR
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6d2294b2
......@@ -69,13 +69,13 @@ def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]:
return _fp8_tensors_recompute_buffer
def set_global_fp8_recompute_buffer(buffer: List[Deque[torch.Tensor]]) -> None:
def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None:
"""Sets global fp8 recompute buffer."""
global _fp8_tensors_recompute_buffer
# Map all tensors back to GPU.
for index, deck in enumerate(buffer):
buffer[index] = deque([tensor.cuda() for tensor in deck])
buffer[index] = deque([[t.cuda() for t in tensors] for tensors in deck])
_fp8_tensors_recompute_buffer = buffer
......@@ -118,11 +118,11 @@ def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> Non
global _fp8_tensors_recompute_buffer
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
to_copy = (
to_copy = [
fp8_meta["scaling_fwd"].amax_history.clone(),
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
]
if buffer_position_key in fp8_meta:
_fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment