Unverified commit 681bf9ad authored by Kirthi Shankar Sivamani, committed by GitHub

Make amax reduction optional (#11)



* Make amax reduction optional
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove setup for global amax reduction for the optional case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve documentation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address documentation review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Documentation fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better FP8 checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make checkpointing backwards compatible
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add deprecation warning for old checkpoint loading
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix checkpointing for FP8 recompute case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improvements to deprecation warning
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent c149c145
@@ -98,6 +98,14 @@ class DelayedScaling:
     override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
                               Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
                               GEMMs (respectively) in higher precision when using FP8.
+    reduce_amax: bool, default = `True`
+                 By default, if `torch.distributed` is initialized, the `amax` value for FP8
+                 tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
+                 call). This keeps the amaxes and scaling factors synced across the given
+                 distributed group. If set to `False`, this reduction is skipped and every
+                 GPU maintains local amaxes and scaling factors. To ensure results are
+                 numerically identical across checkpointing boundaries in this case, all
+                 ranks must checkpoint in order to store the local tensors.

     Notes
     -----
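To make the new recipe field concrete, here is a minimal usage sketch (not part of the diff). `DelayedScaling` and `fp8_autocast` are the Transformer Engine APIs this PR touches; the `te.Linear` module, shapes, and device are illustrative.

```python
# Minimal sketch: opt out of distributed amax reduction via the new flag.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# With reduce_amax=False, each GPU keeps local amaxes and scaling factors;
# no reduction is performed across the fp8_group.
recipe = DelayedScaling(fp8_format=Format.HYBRID, reduce_amax=False)

model = te.Linear(768, 768)  # illustrative module and sizes
inp = torch.randn(32, 768, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)
```

As the docstring notes, when the reduction is skipped, every rank must write its own checkpoint to keep results numerically identical across a save/restore boundary.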
@@ -121,6 +129,7 @@ class DelayedScaling:
     amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent"
     override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
     scaling_factor_compute_algo: Optional[Callable] = None
+    reduce_amax: bool = True

     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
@@ -5,7 +5,7 @@
 """FP8 utilities for TransformerEngine"""
 from contextlib import contextmanager
 from collections import deque
-from typing import Callable, List, Optional, Dict, Any, Tuple, Union
+from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque
 import torch
 import transformer_engine_extensions as tex
@@ -64,6 +64,22 @@ def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
     _global_fp8_buffer = buffer


+def get_global_fp8_recompute_buffer() -> List[Deque[torch.Tensor]]:
+    """Returns the global FP8 recompute buffer."""
+    return _fp8_tensors_recompute_buffer
+
+
+def set_global_fp8_recompute_buffer(buffer: List[Deque[torch.Tensor]]) -> None:
+    """Sets the global FP8 recompute buffer."""
+    global _fp8_tensors_recompute_buffer
+
+    # Map all tensors back to GPU.
+    for index, deck in enumerate(buffer):
+        buffer[index] = deque([tensor.cuda() for tensor in deck])
+
+    _fp8_tensors_recompute_buffer = buffer
+
+
 def setup_amax_forward_global_reduce_func(f: Callable) -> None:
     """Sets up the function to call during autocast exit."""
     global _amax_forward_global_reduce_func
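A note on the `.cuda()` remapping in `set_global_fp8_recompute_buffer` above: the stored tensors are device-agnostic in the checkpoint and are moved back to the GPU on restore. Below is a minimal sketch of the inverse mapping at save time; the `_to_cpu` helper is hypothetical and not part of this diff.

```python
from collections import deque

# Hypothetical helper: the inverse of the `.cuda()` mapping above. Moving
# every tensor in every deque to CPU lets the buffer pickle cleanly inside
# a checkpoint, independent of which device it is later restored on.
def _to_cpu(buffer):
    return [deque(tensor.cpu() for tensor in deck) for deck in buffer]

# Sketch of the round trip:
#   state = _to_cpu(get_global_fp8_recompute_buffer())  # save time
#   set_global_fp8_recompute_buffer(state)              # load time; maps back to GPU
```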
@@ -31,6 +31,8 @@ from .fp8 import (
     amax_and_scale_update,
     get_global_fp8_buffer,
     set_global_fp8_buffer,
+    get_global_fp8_recompute_buffer,
+    set_global_fp8_recompute_buffer,
     set_amax_buffer_key_deletion,
     delete_key_from_amax_buffer,
     copy_forward_fp8_meta_tensors_for_recompute,
@@ -145,17 +147,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def get_extra_state(self) -> Union[List[Any], None]:
         """Save before checkpointing."""
         if self.fp8:
-            state = []
-            state.append(self.fp8_meta["scaling_fwd"].scale)
-            state.append(self.fp8_meta["scaling_fwd"].amax_history)
-            state.append(self.fp8_meta["scaling_bwd"].scale)
-            state.append(self.fp8_meta["scaling_bwd"].amax_history)
-            state.append(get_global_fp8_buffer())
-            state.append(self.fp8_meta["update_amax_and_scale_fwd"])
-            state.append(self.fp8_meta["global_fp8_buffer_pos_fwd"])
-            state.append(self.fp8_meta["global_fp8_buffer_pos_bwd"])
-            state.append(self.fp8_meta["autocast_id_fwd"])
-            state.append(self.fp8_meta["autocast_id_bwd"])
+            state = {}
+            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
+            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
+            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
+            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
+            state["global_fp8_buffer"] = get_global_fp8_buffer()
+            state["global_fp8_recompute_buffer"] = get_global_fp8_recompute_buffer()
+
+            # Store other picklable values.
+            extra = {}
+            for k, v in self.fp8_meta.items():
+                if isinstance(v, (bool, int, float, str)):
+                    extra[k] = v
+            state["extra_fp8_variables"] = extra
+
             return state
         return None
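`get_extra_state` is one of the standard `torch.nn.Module` serialization hooks, so the FP8 metadata above travels inside an ordinary `state_dict()` round trip. A minimal sketch; the `te.Linear` module and file name are illustrative:

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(768, 768)  # illustrative FP8-capable module
# ... run some steps under te.fp8_autocast so the module has FP8 state ...

# get_extra_state() embeds scales, amax histories, and the global FP8
# buffers into the ordinary state_dict.
torch.save(model.state_dict(), "ckpt.pt")

# set_extra_state() restores them on load.
model2 = te.Linear(768, 768)
model2.load_state_dict(torch.load("ckpt.pt"))
```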
@@ -164,32 +170,56 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return

-        # Retrieve checkpointed items.
-        scale_fwd = state[0]
-        amax_history_fwd = state[1]
-        scale_bwd = state[2]
-        amax_history_bwd = state[3]
-        self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
-        self.fp8_meta["num_gemms"] = (
-            amax_history_fwd.shape[1] // 2
-        )  # Two FWD tensors per GEMM
-
-        # Initialize before loading
+        # Maintain backward compatibility with v0.2.0 and older.
+        if isinstance(state, list):
+            warnings.warn(
+                "This checkpoint format is deprecated and will be "
+                "removed in a future release of Transformer Engine"
+            )
+
+            # Retrieve checkpointed items.
+            scale_fwd = state[0]
+            amax_history_fwd = state[1]
+            scale_bwd = state[2]
+            amax_history_bwd = state[3]
+            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
+            self.fp8_meta["num_gemms"] = (
+                amax_history_fwd.shape[1] // 2
+            )  # Two FWD tensors per GEMM
+
+            # Initialize before loading.
+            self.init_fp8_meta_tensors()
+            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
+            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
+            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
+            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
+            self.fp8_meta_tensors_initialized = True
+
+            # Restore global FP8 buffer state.
+            set_global_fp8_buffer(state[4])
+            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
+            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
+            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
+            self.fp8_meta["autocast_id_fwd"] = state[8]
+            self.fp8_meta["autocast_id_bwd"] = state[9]
+            return
+
+        # Restore global FP8 buffer states.
+        set_global_fp8_buffer(state["global_fp8_buffer"])
+        set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
+
+        # Load extra items.
+        self.fp8_meta.update(state["extra_fp8_variables"])
+        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
+
+        # Initialize before loading.
         self.init_fp8_meta_tensors()
-        self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
-        self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
-        self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
-        self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
+        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
+        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
+        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
+        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
         self.fp8_meta_tensors_initialized = True
-
-        # Restore global FP8 buffer state.
-        set_global_fp8_buffer(state[4])
-        self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
-        self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
-        self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
-        self.fp8_meta["autocast_id_fwd"] = state[8]
-        self.fp8_meta["autocast_id_bwd"] = state[9]

     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
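Because `set_extra_state` now accepts both formats, a v0.2.0-era checkpoint can be migrated by loading it (which goes through the deprecated `isinstance(state, list)` branch above and emits the warning) and re-saving it in the dict format. A sketch, with illustrative file names and the `model` from the earlier checkpointing sketch:

```python
import warnings
import torch

# Loading a list-format checkpoint exercises the deprecated branch above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.load_state_dict(torch.load("old_ckpt.pt"))
    assert any("deprecated" in str(w.message) for w in caught)

# Re-saving writes the new dict-based extra state.
torch.save(model.state_dict(), "new_ckpt.pt")
```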
@@ -310,22 +340,25 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_init(num_gemms=num_gemms)
         self.set_fp8_weights()

-        # Previous iteration was grad_enabled
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            copy_amax_from_global_buffer(self.fp8_meta, forward=True)
-            amax_and_scale_update(self.fp8_meta, True)
-            set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+            # Previous iteration was grad_enabled
+            if self.fp8_meta["recipe"].reduce_amax:
+                copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                amax_and_scale_update(self.fp8_meta, True)
+                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+            else:
+                amax_and_scale_update(self.fp8_meta, True)

         if self.fp8 and self.training:
-            self.fp8_meta["first_module"] = is_first_fp8_module()
-
-            if self.fp8_meta["first_module"]:
-                self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
-                set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
-            else:
-                self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
-
-            add_amax_to_global_buffer(self.fp8_meta, forward=True)
+            # Setup for amax reduction
+            if self.fp8_meta["recipe"].reduce_amax:
+                self.fp8_meta["first_module"] = is_first_fp8_module()
+                if self.fp8_meta["first_module"]:
+                    self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
+                    set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+                else:
+                    self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
+                add_amax_to_global_buffer(self.fp8_meta, forward=True)
             self.fp8_meta["update_amax_and_scale_fwd"] = True
         else:
             self.fp8_meta["update_amax_and_scale_fwd"] = False
@@ -349,7 +382,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             restore_fp8_meta_tensors(self.fp8_meta)
             return

-        if self.fp8 and self.training:
+        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             reduce_func = partial(
                 global_amax_reduction,
@@ -366,6 +399,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if not fp8:
             return

+        # Update amax and scale; Skip all setup for global amax reduction
+        if not fp8_meta["recipe"].reduce_amax:
+            amax_and_scale_update(fp8_meta, False)
+            return
+
         # From previous iteration
         copy_amax_from_global_buffer(fp8_meta, forward=False)
         amax_and_scale_update(fp8_meta, False)
@@ -387,7 +425,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         tp_group: Union[dist_group_type, None],
     ) -> None:
         """Checks and prep for BWD."""
-        if not fp8:
+        if not fp8 or not fp8_meta["recipe"].reduce_amax:
             return

         if fp8_meta["first_module"]: