Unverified Commit d10dfb57 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Full activation recompute checkpointing bug fix (#31)



fix checkpoint loading bug for FAR
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6d2294b2
...@@ -69,13 +69,13 @@ def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]: ...@@ -69,13 +69,13 @@ def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]:
return _fp8_tensors_recompute_buffer return _fp8_tensors_recompute_buffer
def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None:
    """Install *buffer* as the global fp8 recompute buffer.

    Each deque entry is a list of fp8 meta tensors (amax_history, scale,
    scale_inv) saved for activation recompute; they may have been moved to
    CPU for checkpointing, so every tensor is mapped back onto the GPU
    before the buffer is published.
    """
    global _fp8_tensors_recompute_buffer
    # Rebuild each deque with all of its tensor groups restored to GPU memory.
    for pos, stored_deque in enumerate(buffer):
        buffer[pos] = deque(
            [tensor.cuda() for tensor in tensor_group]
            for tensor_group in stored_deque
        )
    _fp8_tensors_recompute_buffer = buffer
...@@ -118,11 +118,11 @@ def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> Non ...@@ -118,11 +118,11 @@ def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> Non
global _fp8_tensors_recompute_buffer global _fp8_tensors_recompute_buffer
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
to_copy = ( to_copy = [
fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].amax_history.clone(),
fp8_meta["scaling_fwd"].scale.clone(), fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone(), fp8_meta["scaling_fwd"].scale_inv.clone(),
) ]
if buffer_position_key in fp8_meta: if buffer_position_key in fp8_meta:
_fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment