Unverified Commit 681bf9ad, authored by Kirthi Shankar Sivamani and committed by GitHub

Make amax reduction optional (#11)



* Make amax reduction optional
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove setup for global amax reduction for the optional case
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve documentation
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address documentation review
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Documentation fix
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better FP8 checkpointing
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make checkpointing backwards compatible
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add deprecation warning for old checkpoint loading
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix checkpointing for the FP8 recompute case
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improvements to the deprecation warning
  Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent c149c145
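As context for the diff below: the new `reduce_amax` knob is consumed through the recipe object passed to `fp8_autocast`. A minimal usage sketch, assuming a single-GPU setup (the `te.Linear` module, sizes, and input are illustrative):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Skip the distributed amax reduction; each GPU keeps local amaxes/scales.
recipe = DelayedScaling(fp8_format=Format.HYBRID, reduce_amax=False)

model = te.Linear(1024, 1024)               # illustrative TE module
inp = torch.randn(32, 1024, device="cuda")  # illustrative input

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)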
@@ -98,6 +98,14 @@ class DelayedScaling:
     override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
                               Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
                               GEMMs (respectively) in higher precision when using FP8.
+    reduce_amax: bool, default = `True`
+                By default, if `torch.distributed` is initialized, the `amax` value for FP8
+                tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
+                call). This keeps the amaxes and scaling factors synced across the given
+                distributed group. If set to `False`, this reduction is skipped and every
+                GPU maintains local amaxes and scaling factors. To ensure results are
+                numerically identical across checkpointing boundaries in this case, all
+                ranks must checkpoint in order to store the local tensors.

     Notes
     -----
@@ -121,6 +129,7 @@ class DelayedScaling:
     amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent"
     override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
     scaling_factor_compute_algo: Optional[Callable] = None
+    reduce_amax: bool = True

     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
...
@@ -5,7 +5,7 @@
 """FP8 utilities for TransformerEngine"""
 from contextlib import contextmanager
 from collections import deque
-from typing import Callable, List, Optional, Dict, Any, Tuple, Union
+from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque

 import torch
 import transformer_engine_extensions as tex
@@ -64,6 +64,22 @@ def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
     _global_fp8_buffer = buffer


+def get_global_fp8_recompute_buffer() -> List[Deque[torch.Tensor]]:
+    """Returns the global fp8 recompute buffer."""
+    return _fp8_tensors_recompute_buffer
+
+
+def set_global_fp8_recompute_buffer(buffer: List[Deque[torch.Tensor]]) -> None:
+    """Sets the global fp8 recompute buffer."""
+    global _fp8_tensors_recompute_buffer
+
+    # Map all tensors back to GPU.
+    for index, deck in enumerate(buffer):
+        buffer[index] = deque([tensor.cuda() for tensor in deck])
+
+    _fp8_tensors_recompute_buffer = buffer
+
+
 def setup_amax_forward_global_reduce_func(f: Callable) -> None:
     """Sets up the function to call during autocast exit."""
     global _amax_forward_global_reduce_func
...
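A note on the `.cuda()` mapping in `set_global_fp8_recompute_buffer` above: tensors restored from a checkpoint may arrive on CPU (for example, after `torch.load(..., map_location="cpu")`), so the setter moves every deque entry back to the GPU. The same transformation in isolation, with made-up buffer contents:

import torch
from collections import deque

buffer = [deque([torch.zeros(4), torch.ones(4)])]      # hypothetical CPU tensors
buffer = [deque(t.cuda() for t in d) for d in buffer]  # the mapping the setter applies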
@@ -31,6 +31,8 @@ from .fp8 import (
     amax_and_scale_update,
     get_global_fp8_buffer,
     set_global_fp8_buffer,
+    get_global_fp8_recompute_buffer,
+    set_global_fp8_recompute_buffer,
     set_amax_buffer_key_deletion,
     delete_key_from_amax_buffer,
     copy_forward_fp8_meta_tensors_for_recompute,
@@ -145,17 +147,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
-    def get_extra_state(self) -> Union[List[Any], None]:
+    def get_extra_state(self) -> Union[Dict[str, Any], None]:
         """Save before checkpointing."""
         if self.fp8:
-            state = []
-            state.append(self.fp8_meta["scaling_fwd"].scale)
-            state.append(self.fp8_meta["scaling_fwd"].amax_history)
-            state.append(self.fp8_meta["scaling_bwd"].scale)
-            state.append(self.fp8_meta["scaling_bwd"].amax_history)
-            state.append(get_global_fp8_buffer())
-            state.append(self.fp8_meta["update_amax_and_scale_fwd"])
-            state.append(self.fp8_meta["global_fp8_buffer_pos_fwd"])
-            state.append(self.fp8_meta["global_fp8_buffer_pos_bwd"])
-            state.append(self.fp8_meta["autocast_id_fwd"])
-            state.append(self.fp8_meta["autocast_id_bwd"])
+            state = {}
+            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
+            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
+            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
+            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
+            state["global_fp8_buffer"] = get_global_fp8_buffer()
+            state["global_fp8_recompute_buffer"] = get_global_fp8_recompute_buffer()
+
+            # Store other picklable values.
+            extra = {}
+            for k, v in self.fp8_meta.items():
+                if isinstance(v, (bool, int, float, str)):
+                    extra[k] = v
+            state["extra_fp8_variables"] = extra
+
             return state
         return None
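The dict-format state above travels through PyTorch's standard extra-state hooks, so no new save/load API is involved. A hedged round-trip sketch (module choice and file name are illustrative):

import torch
import transformer_engine.pytorch as te

model = te.Linear(256, 256)  # illustrative module holding FP8 state

# state_dict() calls get_extra_state() per module, embedding the FP8 scales,
# amax histories, and global buffers in the checkpoint.
torch.save(model.state_dict(), "model.pt")

# load_state_dict() routes the saved dict back through set_extra_state().
model.load_state_dict(torch.load("model.pt"))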
@@ -164,6 +170,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return

+        # Maintain backward compatibility with v0.2.0 and older.
+        if isinstance(state, list):
+            warnings.warn(
+                "This checkpoint format is deprecated and will be "
+                "removed in a future release of Transformer Engine."
+            )
+
             # Retrieve checkpointed items.
             scale_fwd = state[0]
             amax_history_fwd = state[1]
@@ -189,6 +202,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
             self.fp8_meta["autocast_id_fwd"] = state[8]
             self.fp8_meta["autocast_id_bwd"] = state[9]
+            return
+
+        # Restore global FP8 buffer states.
+        set_global_fp8_buffer(state["global_fp8_buffer"])
+        set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
+
+        # Load extra items.
+        self.fp8_meta.update(state["extra_fp8_variables"])
+        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
+
+        # Initialize before loading.
+        self.init_fp8_meta_tensors()
+        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
+        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
+        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
+        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
+        self.fp8_meta_tensors_initialized = True

     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
@@ -310,21 +340,24 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_init(num_gemms=num_gemms)
         self.set_fp8_weights()

-        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            # Previous iteration was grad_enabled
-            copy_amax_from_global_buffer(self.fp8_meta, forward=True)
-            amax_and_scale_update(self.fp8_meta, True)
-            set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+        # Previous iteration was grad_enabled
+        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
+            if self.fp8_meta["recipe"].reduce_amax:
+                copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                amax_and_scale_update(self.fp8_meta, True)
+                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+            else:
+                amax_and_scale_update(self.fp8_meta, True)

         if self.fp8 and self.training:
-            self.fp8_meta["first_module"] = is_first_fp8_module()
-            if self.fp8_meta["first_module"]:
-                self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
-                set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
-            else:
-                self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
-            add_amax_to_global_buffer(self.fp8_meta, forward=True)
+            # Setup for amax reduction
+            if self.fp8_meta["recipe"].reduce_amax:
+                self.fp8_meta["first_module"] = is_first_fp8_module()
+                if self.fp8_meta["first_module"]:
+                    self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
+                    set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+                else:
+                    self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
+                add_amax_to_global_buffer(self.fp8_meta, forward=True)
             self.fp8_meta["update_amax_and_scale_fwd"] = True
         else:
@@ -349,7 +382,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             restore_fp8_meta_tensors(self.fp8_meta)
             return

-        if self.fp8 and self.training:
+        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
             set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             reduce_func = partial(
                 global_amax_reduction,
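For context, the work that `reduce_amax=False` now skips is conceptually a max all-reduce of the accumulated amaxes across the `fp8_group`. A simplified sketch of that idea, not TE's actual `global_amax_reduction` (the helper name is hypothetical):

import torch
import torch.distributed as dist

def reduce_amaxes(amaxes, group=None):
    """Hypothetical helper: sync a list of amax tensors across ranks in one all-reduce."""
    flat = torch.cat([a.view(-1) for a in amaxes])
    dist.all_reduce(flat, op=dist.ReduceOp.MAX, group=group)
    offset = 0
    for a in amaxes:
        a.copy_(flat[offset:offset + a.numel()].view_as(a))
        offset += a.numel()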
@@ -366,6 +399,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if not fp8:
             return

+        # Update amax and scale; skip all setup for global amax reduction.
+        if not fp8_meta["recipe"].reduce_amax:
+            amax_and_scale_update(fp8_meta, False)
+            return
+
         # From previous iteration
         copy_amax_from_global_buffer(fp8_meta, forward=False)
         amax_and_scale_update(fp8_meta, False)
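When the reduction is skipped, `amax_and_scale_update` still runs with purely local amaxes. For delayed scaling the per-tensor update is roughly the standard formula below (a sketch under default recipe settings, not TE's kernel):

def compute_scale(amax: float, fp8_max: float, margin: int = 0) -> float:
    """Map an observed amax to a scaling factor that fills the FP8 range."""
    return (fp8_max / amax) / (2 ** margin)

# e.g. E4M3 (fp8_max = 448.0) with an observed amax of 3.5:
# compute_scale(3.5, 448.0) -> 128.0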
@@ -387,7 +425,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         tp_group: Union[dist_group_type, None],
     ) -> None:
         """Checks and prep for BWD."""
-        if not fp8:
+        if not fp8 or not fp8_meta["recipe"].reduce_amax:
             return

         if fp8_meta["first_module"]:
...
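Finally, tying back to the docstring note on checkpointing: with `reduce_amax=False` every rank holds distinct scaling state, so each rank must save and later restore its own checkpoint. A hedged sketch, assuming `torch.distributed` is initialized (module choice and paths are made up):

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

model = te.Linear(128, 128)  # illustrative module with rank-local FP8 state

rank = dist.get_rank()
torch.save(model.state_dict(), f"ckpt_rank{rank}.pt")     # every rank saves...
model.load_state_dict(torch.load(f"ckpt_rank{rank}.pt"))  # ...and restores its own state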