Unverified commit 2e0bfbd9, authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] FP8 fixes (#380)



* Initial refactor
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reorder methods by purpose
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save full global state
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes to test
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cbfb8c6b
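The change is mechanical at most call sites: the module-level FP8 state helpers in transformer_engine/pytorch/fp8.py are consolidated into classmethods and staticmethods on a new FP8GlobalStateManager class, and every caller in the library and tests is updated to use the class. A minimal sketch of the call-site change (editor's illustration, not part of the diff; assumes Transformer Engine with this PR is installed):

    # Before this PR, callers used module-level functions, e.g.:
    #     from transformer_engine.pytorch.fp8 import is_fp8_enabled, get_fp8_recipe
    #     if is_fp8_enabled():
    #         recipe = get_fp8_recipe()
    # After this PR, the same state lives on one class:
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    if FP8GlobalStateManager.is_fp8_enabled():
        recipe = FP8GlobalStateManager.get_fp8_recipe()
    fp8_ok, why_not = FP8GlobalStateManager.is_fp8_available()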
@@ -3,6 +3,7 @@ extension-pkg-whitelist=torch,
                         transformer_engine_extensions
disable=too-many-locals,
+        too-many-public-methods,
        invalid-name,
        too-many-arguments,
        abstract-method,
...
@@ -10,7 +10,7 @@ from transformer_engine.pytorch.utils import (
    scaled_init_method_normal,
    get_device_compute_capability,
)
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
import os
@@ -18,7 +18,7 @@ import os
from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
...
@@ -38,7 +38,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
from transformer_engine.pytorch.export import is_in_onnx_export_mode
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Global test configuration knobs.
@@ -66,7 +66,7 @@ assert OPSET >= TRILU_OPSET
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
...
@@ -5,7 +5,7 @@
import torch
import pytest
-from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
@@ -21,7 +21,7 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
# Only run FP8 tests on H100.
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
def custom_amax_to_scale(
...
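In the test files the only change is the import path and the availability query; the familiar pytest skip pattern keeps working. A hedged sketch (assumes pytest and Transformer Engine with this PR; `test_fp8_something` is a placeholder name, not a test from this repository):

    import pytest
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    # Query FP8 support once at import time, as the updated tests do.
    fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_something():  # hypothetical test name, for illustration only
        assert fp8_available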
@@ -12,7 +12,7 @@ from torch.utils.checkpoint import detach_variable
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
-from .fp8 import is_fp8_enabled
+from .fp8 import FP8GlobalStateManager
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
@@ -145,7 +145,8 @@ def activation_recompute_forward(
    """
    global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
    try:
-        _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and is_fp8_enabled()
+        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
+            activation_recompute and FP8GlobalStateManager.is_fp8_enabled())
        _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
        yield
    finally:
...
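The recompute context manager now asks FP8GlobalStateManager whether FP8 is active instead of the removed is_fp8_enabled() helper. A rough standalone sketch of the flag logic it guards (editor's illustration only, not the actual transformer_engine.pytorch.distributed code; the reset in the `finally` block is assumed, since the diff truncates that part):

    from contextlib import contextmanager

    _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
    _FP8_ACTIVATION_RECOMPUTE_PHASE = False

    def _is_fp8_enabled() -> bool:
        # Stand-in for FP8GlobalStateManager.is_fp8_enabled().
        return True

    @contextmanager
    def activation_recompute_forward(activation_recompute: bool, recompute_phase: bool):
        """Mark the forward pass as FP8 activation recompute, restoring flags on exit."""
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        try:
            _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and _is_fp8_enabled()
            _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
            yield
        finally:
            _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
            _FP8_ACTIVATION_RECOMPUTE_PHASE = False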
@@ -16,28 +16,11 @@ from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser

-_FP8_ENABLED = False
-_FP8_CALIBRATION = False
-_FP8_RECIPE = None
-_FP8_DISTRIBUTED_GROUP = None
-_IS_FIRST_FP8_MODULE = False
-_FP8_AUTOCAST_COUNTER = 0
-_FP8_CURRENT_CONTEXT_ID = 0
-_FP8_AUTOCAST_DEPTH = 0
-_global_fp8_buffer = {}
-_fp8_tensors_recompute_buffer = []
-_amax_forward_global_reduce_func = None
-_buffer_delete_key_fwd = None
-_buffer_delete_key_bwd = None
-_amax_reduce_handle_fwd = None
-_is_fp8_available = None
-_reason_for_no_fp8 = ""
-_dp_amax_reduce_interval = None
-_dp_amax_reduce_forward_idx = 0
-_dp_amax_reduce_backward_idx = 0
+__all__ = ["fp8_autocast"]

-def _check_fp8_support() -> Tuple[bool, str]:
+def check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if get_device_compute_capability() >= 9.0:  # hopper and above
        return True, ""
@@ -50,99 +33,374 @@ def _check_fp8_support() -> Tuple[bool, str]:
        return True, ""

-def is_fp8_available() -> Tuple[bool, str]:
+def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return DelayedScaling()
def get_fp8_te_dtype(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
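get_fp8_te_dtype maps the recipe format to a tensor dtype: E4M3 everywhere for Format.E4M3, and E4M3 for forward tensors but E5M2 for gradient tensors under Format.HYBRID. A hedged usage sketch (assumes Transformer Engine with this PR is installed):

    from transformer_engine.common.recipe import DelayedScaling, Format
    from transformer_engine.pytorch.fp8 import get_fp8_te_dtype

    recipe = DelayedScaling(fp8_format=Format.HYBRID)
    fwd_dtype = get_fp8_te_dtype(recipe, fprop_tensor=True)   # kFloat8E4M3
    bwd_dtype = get_fp8_te_dtype(recipe, fprop_tensor=False)  # kFloat8E5M2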
class FP8GlobalStateManager:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
"""
FP8_ENABLED = False
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
IS_FIRST_FP8_MODULE = False
FP8_AUTOCAST_COUNTER = 0
FP8_CURRENT_CONTEXT_ID = 0
FP8_AUTOCAST_DEPTH = 0
global_fp8_buffer = {}
fp8_tensors_recompute_buffer = []
amax_forward_global_reduce_func = None
buffer_delete_key_fwd = None
buffer_delete_key_bwd = None
amax_reduce_handle_fwd = None
fp8_available = None
reason_for_no_fp8 = ""
dp_amax_reduce_interval = None
dp_amax_reduce_forward_idx = 0
dp_amax_reduce_backward_idx = 0
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
global _is_fp8_available, _reason_for_no_fp8 if cls.fp8_available is None:
if _is_fp8_available is None: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
_is_fp8_available, _reason_for_no_fp8 = _check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8
return _is_fp8_available, _reason_for_no_fp8
@classmethod
def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]:
"""Returns global fp8 state variables."""
# Convert attributes to dictionary to make future proof against
# changes in global state variables in order to make setting the
# checkpoint backwards compatible.
global_fp8_state = {}
global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER
global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID
global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH
global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd
global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd
global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval
global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx
global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx
return global_fp8_state
@classmethod
def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None:
"""Sets global fp8 state variables."""
for k, v in state.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 amax buffer."""
return cls.global_fp8_buffer
@classmethod
def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None:
"""Sets global fp8 amax buffer."""
# Map all tensors back to GPU.
for k, v in buffer.items():
buffer[k] = [tensor.cuda() for tensor in v]
cls.global_fp8_buffer = buffer
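The two checkpoint pairs above exist so that a training-state save can capture both the global amax buffer and the scalar counters, and restore them later; the state is passed around as plain dicts to keep the format backward compatible. A hedged round-trip sketch (assumes Transformer Engine with this PR; the file name is illustrative, and a CUDA device is required because the buffer setter maps tensors back to the GPU):

    import torch
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    # Capture the global FP8 bookkeeping alongside the model checkpoint.
    ckpt = {
        "global_fp8_state": FP8GlobalStateManager.get_global_fp8_state_checkpoint(),
        "global_fp8_buffer": FP8GlobalStateManager.get_global_fp8_buffer_checkpoint(),
    }
    torch.save(ckpt, "fp8_state.pt")

    # ... later, after re-creating the model ...
    restored = torch.load("fp8_state.pt")
    FP8GlobalStateManager.set_global_fp8_state_checkpoint(restored["global_fp8_state"])
    FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(restored["global_fp8_buffer"])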
-def get_meta_tensor_key(forward: bool = True) -> str:
+    @staticmethod
+    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"
@staticmethod
    def get_buffer_position_key(forward: bool = True) -> str:
        """Returns module position key in `fp8_meta`."""
        if forward:
            return "global_fp8_buffer_pos_fwd"
        return "global_fp8_buffer_pos_bwd"

    @staticmethod
    def get_autocast_key(forward: bool = True) -> str:
        """Returns module position key in `fp8_meta`."""
        if forward:
            return "autocast_id_fwd"
        return "autocast_id_bwd"
@staticmethod
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
-def get_amax_reduce_handle_fwd() -> Union[bool, None]:
-    """Return AMAX reduction wait handle of forward prop."""
-    global _amax_reduce_handle_fwd
-    return _amax_reduce_handle_fwd
+    @classmethod
+    def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]:
+        """Return AMAX reduction wait handle of forward prop."""
+        return cls.amax_reduce_handle_fwd

-def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
-    """Returns global fp8 buffer."""
-    return _global_fp8_buffer
+    @classmethod
+    def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None:
+        """Sets up the function to call during autocast exit."""
+        cls.amax_forward_global_reduce_func = f
@classmethod
def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None:
"""Append 1D tensor `amax` to global buffer."""
buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
-def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
-    """Sets global fp8 buffer."""
-    global _global_fp8_buffer
-    # Map all tensors back to GPU.
-    for k, v in buffer.items():
-        buffer[k] = [tensor.cuda() for tensor in v]
-    _global_fp8_buffer = buffer
+        if buffer_key not in cls.global_fp8_buffer:
+            cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
+        else:
+            cls.global_fp8_buffer[buffer_key].append(
+                fp8_meta[fp8_meta_tensor_key].amax_history[0]
+            )
+        if buffer_position_key not in fp8_meta:
+            fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1
+        # Catch incorrect fp8_autocast usage.
+        assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \
+            "Same module is being invoked more than once inside an `fp8_autocast` " \
+            "region when using FP8 with amax reduction. This behavior is currently" \
+            " unsupported. For more details and correct usage, please see " \
+            "https://github.com/NVIDIA/TransformerEngine/pull/93."
@classmethod
def copy_amax_from_global_buffer(
cls, fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error."
-def setup_amax_forward_global_reduce_func(f: Callable) -> None:
-    """Sets up the function to call during autocast exit."""
-    global _amax_forward_global_reduce_func
-    _amax_forward_global_reduce_func = f
+        fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][
+            fp8_meta[buffer_position_key]
+        ]
@classmethod
def set_amax_buffer_key_deletion(
cls, fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if cls.get_autocast_key(forward=forward) not in fp8_meta:
return
if forward:
cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
else:
cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
-def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
-    """Return a key in `_global_fp8_buffer` for the AMAX storage."""
-    if forward:
-        return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
-    return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
+    @classmethod
+    def delete_key_from_amax_buffer(cls, forward: bool = True) -> None:
+        """Delete the key from global amax buffer."""
+        if forward:
+            if (
+                cls.buffer_delete_key_fwd is not None
+                and cls.buffer_delete_key_fwd in cls.global_fp8_buffer
+            ):
+                del cls.global_fp8_buffer[cls.buffer_delete_key_fwd]
else:
if (
cls.buffer_delete_key_bwd is not None
and cls.buffer_delete_key_bwd in cls.global_fp8_buffer
):
del cls.global_fp8_buffer[cls.buffer_delete_key_bwd]
@classmethod
def get_fp8_context_id(cls) -> int:
"""Returns an ID for the current FP8 context."""
return cls.FP8_CURRENT_CONTEXT_ID
-def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) -> None:
-    """Append 1D tensor `amax` to global buffer."""
-    global _global_fp8_buffer
-    buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
-    fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
-    buffer_position_key = get_buffer_position_key(forward=forward)
+    @classmethod
+    def set_fp8_context_id(cls, ctx_id: int) -> None:
+        """Sets the current FP8 context."""
+        cls.FP8_CURRENT_CONTEXT_ID = ctx_id

+    @classmethod
def new_fp8_context_id(cls) -> int:
"""Returns global autocast counter as a proxy to be used
as the autocast ID for FP8 modules.
"""
return cls.FP8_AUTOCAST_COUNTER
@classmethod
def is_fp8_enabled(cls) -> bool:
"""Is FP8 enabled"""
return cls.FP8_ENABLED
@classmethod
def is_fp8_calibration(cls) -> bool:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
tmp = cls.IS_FIRST_FP8_MODULE
cls.IS_FIRST_FP8_MODULE = False
return tmp
@classmethod
def get_fp8_recipe(cls) -> DelayedScaling:
"""Return the fp8 recipe"""
return cls.FP8_RECIPE
@classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return cls.FP8_DISTRIBUTED_GROUP
@classmethod
def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]:
"""FP8 autocast state getter"""
return (
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE)
@classmethod
def set_fp8_autocast_state(
cls,
fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool]
) -> None:
"""FP8 autocast state setter"""
(cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE) = fp8_state
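get_fp8_autocast_state/set_fp8_autocast_state bundle the five per-region variables into one tuple so fp8_autocast can snapshot them on entry and restore them on exit, which is what makes nested autocast regions safe. A hedged sketch of that pattern (editor's illustration only; the fp8_autocast context manager below already does this for you, and enabled=True requires an FP8-capable GPU):

    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    saved = FP8GlobalStateManager.get_fp8_autocast_state()
    try:
        FP8GlobalStateManager.fp8_autocast_enter(enabled=True)
        ...  # run FP8 modules here
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(saved)
        FP8GlobalStateManager.fp8_autocast_exit()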
@staticmethod
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=async_op,
)
return wait_handle
return None
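global_amax_reduction (below) concatenates every module's one-element amax tensor into a single contiguous tensor, performs one MAX all-reduce over it, and splits the result back into per-module views. The shape bookkeeping can be seen in a small standalone sketch (plain PyTorch, no distributed setup; editor's illustration):

    import torch

    # Pretend three modules each contributed a 1-element amax tensor.
    buffer = [torch.tensor([0.5]), torch.tensor([2.0]), torch.tensor([1.25])]

    chunk_sizes = [x.numel() for x in buffer]          # [1, 1, 1]
    contiguous_amax = torch.cat(buffer)                # shape [3]: one all-reduce instead of three
    # torch.distributed.all_reduce(contiguous_amax, op=ReduceOp.MAX) would go here.
    buffer = list(contiguous_amax.split(chunk_sizes))  # back to three 1-element views
    print(buffer)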
@classmethod
def global_amax_reduction(
cls,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
forward: bool = True,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
# Key already deleted.
if amax_buffer_key not in cls.global_fp8_buffer:
return None
-    if buffer_key not in _global_fp8_buffer:
-        _global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
+        # Reduce AMAX in DP-domain at an interval.
+        if cls.dp_amax_reduce_interval is None:
cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
if forward:
if cls.dp_amax_reduce_forward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
-    else:
-        _global_fp8_buffer[buffer_key].append(
-            fp8_meta[fp8_meta_tensor_key].amax_history[0]
+            else:
+                tp_amax_reduce = True
+            cls.dp_amax_reduce_forward_idx = (
(cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval)
else:
if cls.dp_amax_reduce_backward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
cls.dp_amax_reduce_backward_idx = (
(cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval)
if tp_amax_reduce:
if tp_size > 1:
reduce_group = tp_group
else:
return None
chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key])
wait_handle = cls.reduce_tensor_across_group_op_max(
contiguous_amax,
reduce_group,
fp8_meta["async_amax_reduction"],
-    )
-    if buffer_position_key not in fp8_meta:
-        fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1
-    # Catch incorrect fp8_autocast usage.
-    assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \
-        "Same module is being invoked more than once inside an `fp8_autocast` region when using " \
-        "FP8 with amax reduction. This behavior is currently unsupported. For more details and " \
-        "correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93."
+        )
+        cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
+        return wait_handle

+    @classmethod
+    def fp8_autocast_enter(
+        cls,
+        enabled: bool = False,
+        calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group
if cls.FP8_AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True
cls.FP8_AUTOCAST_COUNTER += 1
cls.FP8_AUTOCAST_DEPTH += 1
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
@classmethod
def fp8_autocast_exit(cls):
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
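fp8_autocast_enter/fp8_autocast_exit track a nesting depth: IS_FIRST_FP8_MODULE and the autocast counter are only touched when the outermost region is entered, and the forward amax reduction plus buffer-key cleanup run only when the outermost region exits. A hedged sketch of what that means for nested regions (assumes an FP8-capable GPU and Transformer Engine with this PR):

    from transformer_engine.pytorch import fp8_autocast

    # Nesting is tracked via FP8_AUTOCAST_DEPTH, so the reduction/cleanup in
    # fp8_autocast_exit only fires when the outermost region ends.
    with fp8_autocast(enabled=True):        # depth 0 -> 1, IS_FIRST_FP8_MODULE set
        with fp8_autocast(enabled=True):    # depth 1 -> 2, no new "first module" reset
            ...
        # depth back to 1: no amax reduction yet
    # depth back to 0: amax_forward_global_reduce_func (if registered) runs, fwd key deleted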
-def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
+    @classmethod
+    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
-    global _fp8_tensors_recompute_buffer
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        to_copy = [
@@ -152,17 +410,17 @@ def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
        ]
        if buffer_position_key in fp8_meta:
-        _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
+            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
-        if len(_fp8_tensors_recompute_buffer) == 0:
-            _fp8_tensors_recompute_buffer = [deque()]
+            if len(cls.fp8_tensors_recompute_buffer) == 0:
+                cls.fp8_tensors_recompute_buffer = [deque()]
            else:
-            _fp8_tensors_recompute_buffer.append(deque())
-        _fp8_tensors_recompute_buffer[-1].append(to_copy)
-    fp8_meta[buffer_position_key] = len(_fp8_tensors_recompute_buffer) - 1
+                cls.fp8_tensors_recompute_buffer.append(deque())
+            cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
+        fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1
@classmethod
-def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
+    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for indentical numerical outputs.
        """
@@ -174,7 +432,7 @@ def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
-    stashed_fp8_meta = _fp8_tensors_recompute_buffer[
+        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[
            fp8_meta[buffer_position_key]
        ].popleft()
@@ -183,51 +441,14 @@ def get_old_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> None:
        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
def copy_amax_from_global_buffer(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
buffer_position_key = get_buffer_position_key(forward=forward)
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
assert amax_buffer_key in _global_fp8_buffer, "TE internal error."
fp8_meta[fp8_meta_tensor_key].amax_history[0] = _global_fp8_buffer[amax_buffer_key][
fp8_meta[buffer_position_key]
]
def set_amax_buffer_key_deletion(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if get_autocast_key(forward=forward) not in fp8_meta:
return
global _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
_buffer_delete_key_fwd = get_amax_buffer_key(fp8_meta, forward=forward)
else:
_buffer_delete_key_bwd = get_amax_buffer_key(fp8_meta, forward=forward)
def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return DelayedScaling()
@contextmanager
def fp8_autocast(
    enabled: bool = False,
@@ -272,96 +493,16 @@ def fp8_autocast(
        distributed group over which amaxes for the fp8 tensors
        are reduced at the end of each training step.
    """
-    global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
-    global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
-    global _global_fp8_buffer, _buffer_delete_key_fwd
-    global _amax_reduce_handle_fwd
-    fp8_state = (
-        _FP8_ENABLED,
-        _FP8_CALIBRATION,
-        _FP8_RECIPE,
-        _FP8_DISTRIBUTED_GROUP,
-        _IS_FIRST_FP8_MODULE)
    try:
-        _FP8_ENABLED = enabled
-        _FP8_CALIBRATION = calibrating
-        _FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
-        _FP8_DISTRIBUTED_GROUP = fp8_group
-        if _FP8_AUTOCAST_DEPTH == 0:
-            _IS_FIRST_FP8_MODULE = True
-            _FP8_AUTOCAST_COUNTER += 1
-        _FP8_AUTOCAST_DEPTH += 1
-        if enabled:
-            fp8_available, reason_for_no_fp8 = is_fp8_available()
-            assert fp8_available, reason_for_no_fp8
+        fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
+        FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
        yield
    finally:
-        (_FP8_ENABLED,
-         _FP8_CALIBRATION,
-         _FP8_RECIPE,
-         _FP8_DISTRIBUTED_GROUP,
-         _IS_FIRST_FP8_MODULE) = fp8_state
-        _FP8_AUTOCAST_DEPTH -= 1
-        if _FP8_AUTOCAST_DEPTH == 0:
-            if callable(_amax_forward_global_reduce_func):
-                _amax_reduce_handle_fwd = _amax_forward_global_reduce_func()
-            delete_key_from_amax_buffer(forward=True)
+        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)  # pylint: disable=used-before-assignment
+        FP8GlobalStateManager.fp8_autocast_exit()
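The public fp8_autocast contract is unchanged by the refactor: it now simply snapshots the manager state, calls fp8_autocast_enter, and restores the state and calls fp8_autocast_exit in the finally block. Typical usage, as a hedged sketch (assumes an FP8-capable GPU and Transformer Engine with this PR):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
    linear = te.Linear(768, 768, bias=True).cuda()
    inp = torch.randn(16, 768, device="cuda")

    # FP8 GEMMs run inside the region; amax/scale bookkeeping is handled globally.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = linear(inp)
    out.sum().backward()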
-def get_fp8_context_id() -> int:
-    """Returns an ID for the current FP8 context."""
-    return _FP8_CURRENT_CONTEXT_ID
-def set_fp8_context_id(ctx_id: int) -> None:
-    """Sets the current FP8 context."""
-    global _FP8_CURRENT_CONTEXT_ID
-    _FP8_CURRENT_CONTEXT_ID = ctx_id
-def new_fp8_context_id() -> int:
-    """Returns global autocast counter as a proxy to be used
-    as the autocast ID for FP8 modules.
-    """
-    return _FP8_AUTOCAST_COUNTER
-def is_fp8_enabled() -> bool:
-    """Is FP8 enabled"""
-    return _FP8_ENABLED
-def is_fp8_calibration() -> bool:
-    """Is FP8 calibration"""
-    return _FP8_CALIBRATION
-def is_first_fp8_module():
-    """Returns `True` only the first time when called multiple
-    times from within the same `fp8_autocast` context.
-    """
-    global _IS_FIRST_FP8_MODULE
-    tmp = _IS_FIRST_FP8_MODULE
-    _IS_FIRST_FP8_MODULE = False
-    return tmp
-def get_fp8_recipe() -> DelayedScaling:
-    """Return the fp8 recipe"""
-    return _FP8_RECIPE
-def get_fp8_group() -> Union[dist_group_type, None]:
-    """Return the fp8 group for scale/amax comm"""
-    return _FP8_DISTRIBUTED_GROUP
-def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
+def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Update amax history and set next amax to zero."""
    if amax_history.shape[0] > 1:
        amax_history = torch.roll(amax_history, -1, 0)
@@ -380,7 +521,7 @@ def _default_get_amax(
    else:  # amax_compute_algo == "most_recent"
        amax = amax_history[0].clone()
-    amax_history = update_amax_history(amax_history)
+    amax_history = _update_amax_history(amax_history)
    return amax_history, amax
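_update_amax_history keeps a rolling window: the oldest entry falls off the front of the history and position 0 is reserved for the next iteration's amax. The roll itself is plain torch.roll, shown here as a standalone illustration:

    import torch

    amax_history = torch.tensor([3.0, 1.0, 2.0])      # [current, prev, prev-prev]
    amax_history = torch.roll(amax_history, -1, 0)    # -> [1.0, 2.0, 3.0]
    amax_history[0] = 0.0                             # next amax will be recorded here
    print(amax_history)                               # tensor([0., 2., 3.])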
@@ -415,7 +556,7 @@ def _compute_scaling_factor_inverse(
@jit_fuser
-def fused_amax_and_scale_update(
+def _fused_amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
@@ -460,7 +601,7 @@ def _compute_amax(
    if callable(recipe.amax_compute_algo):
        amax = recipe.amax_compute_algo(amax_history)
-        amax_history = update_amax_history(amax_history)
+        amax_history = _update_amax_history(amax_history)
        return amax_history, amax
    return _default_get_amax(
        amax_history,
@@ -502,7 +643,7 @@ def amax_and_scale_update(
        fp8_meta[fp8_meta_tensor_key].amax_history,
        fp8_meta[fp8_meta_tensor_key].scale,
        fp8_meta[fp8_meta_tensor_key].scale_inv,
-    ) = fused_amax_and_scale_update(
+    ) = _fused_amax_and_scale_update(
        fp8_meta[fp8_meta_tensor_key].amax_history,
        fp8_meta[fp8_meta_tensor_key].scale,
        fp8_meta[fp8_meta_tensor_key].scale_inv,
@@ -529,99 +670,3 @@ def amax_and_scale_update(
        fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
        update_weight_scale_inv,
    )
def get_fp8_te_dtype(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=async_op,
)
return wait_handle
return None
def global_amax_reduction(
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
forward: bool = True,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
global _global_fp8_buffer
amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
# Key already deleted.
if amax_buffer_key not in _global_fp8_buffer:
return None
# Reduce AMAX in DP-domain at an interval.
global _dp_amax_reduce_interval, _dp_amax_reduce_forward_idx, _dp_amax_reduce_backward_idx
if _dp_amax_reduce_interval is None:
_dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
if forward:
if _dp_amax_reduce_forward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
_dp_amax_reduce_forward_idx = (_dp_amax_reduce_forward_idx + 1) % _dp_amax_reduce_interval
else:
if _dp_amax_reduce_backward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
_dp_amax_reduce_backward_idx = (_dp_amax_reduce_backward_idx + 1) % _dp_amax_reduce_interval
if tp_amax_reduce:
if tp_size > 1:
reduce_group = tp_group
else:
return None
chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
wait_handle = reduce_tensor_across_group_op_max(
contiguous_amax,
reduce_group,
fp8_meta["async_amax_reduction"],
)
_global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
return wait_handle
def delete_key_from_amax_buffer(forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
global _global_fp8_buffer, _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
if (
_buffer_delete_key_fwd is not None
and _buffer_delete_key_fwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_fwd]
else:
if (
_buffer_delete_key_bwd is not None
and _buffer_delete_key_bwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_bwd]
@@ -19,29 +19,10 @@ from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
from ..export import is_in_onnx_export_mode
from ..fp8 import (
-    is_fp8_enabled,
-    is_fp8_calibration,
-    get_fp8_recipe,
-    get_fp8_group,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
-    is_first_fp8_module,
-    new_fp8_context_id,
-    get_fp8_context_id,
-    set_fp8_context_id,
-    add_amax_to_global_buffer,
-    copy_amax_from_global_buffer,
-    global_amax_reduction,
-    setup_amax_forward_global_reduce_func,
+    FP8GlobalStateManager,
    amax_and_scale_update,
-    get_global_fp8_buffer,
-    set_global_fp8_buffer,
-    set_amax_buffer_key_deletion,
-    delete_key_from_amax_buffer,
-    copy_forward_fp8_meta_tensors_for_recompute,
-    get_old_fp8_meta_tensors_for_recompute,
-    restore_fp8_meta_tensors,
-    get_amax_reduce_handle_fwd,
)
from ..distributed import (
    gather_along_first_dim,
@@ -97,30 +78,30 @@ def _prepare_backward(
        # Update amax and scale; Skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
-            amax_and_scale_update(fp8_meta, False)
+            FP8GlobalStateManager.amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
-            copy_amax_from_global_buffer(fp8_meta, forward=False)
+            FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
-            set_amax_buffer_key_deletion(fp8_meta, forward=False)
+            FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key.
            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

-            add_amax_to_global_buffer(fp8_meta, forward=False)
+            FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if fp8 and fp8_meta["recipe"].reduce_amax:
        if fp8_meta["first_module"]:
-            _amax_reduce_handle_bwd = global_amax_reduction(
+            _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
                fp8_meta,
                tp_group,
                tp_size,
                forward=False
            )
-            delete_key_from_amax_buffer(forward=False)
+            FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub(
@@ -356,12 +337,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
        state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
        state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-        state["global_fp8_buffer"] = get_global_fp8_buffer()
+        state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
+        state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()

        # Store other pickelable values.
        extra = {}
        for k, v in self.fp8_meta.items():
-            if isinstance(v, (bool, int, float, str)):
+            if isinstance(v, (bool, int, float, str, list)):
                extra[k] = v
        state["extra_fp8_variables"] = extra
@@ -403,7 +385,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)

        # Restore global FP8 buffer state.
-        set_global_fp8_buffer(state[4])
+        FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
        self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
        self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
        self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
@@ -420,8 +402,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if state is None:
            return

-        # Restore global FP8 buffer states.
-        set_global_fp8_buffer(state["global_fp8_buffer"])
+        # Restore global FP8 amax buffer.
+        FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
+
+        # Restore global FP8 state.
+        if "global_fp8_state" in state:
+            FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
+        else:
+            warnings.warn(
+                "This checkpoint format is deprecated and will be"
+                "removed in a future release of Transformer Engine"
+            )

        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
@@ -525,19 +515,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
-        self.fp8 = is_fp8_enabled()
-        self.fp8_calibration = is_fp8_calibration()
+        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        if self.fp8 or self.fp8_calibration:
            # FP8 init has already been run and recipe is the same, don't do anything.
-            if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
+            if (self.fp8_initialized
+                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]):
                return

            # Set FP8, recipe, and other FP8 metadata
-            self.fp8_meta["recipe"] = get_fp8_recipe()
+            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = get_fp8_group()
+            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
@@ -567,7 +558,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Activation recomputation is used and this is the second forward phase.
            if self.fp8 and in_fp8_activation_recompute_phase():
-                get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
+                FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -591,11 +582,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Previous iteration was grad_enabled
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                if self.fp8_meta["recipe"].reduce_amax:
-                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                    FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                    )
-                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+                    FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
                else:
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
@@ -604,20 +595,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if self.fp8 and self.training:
                # Setup for amax reduction
                if self.fp8_meta["recipe"].reduce_amax:
-                    self.fp8_meta["first_module"] = is_first_fp8_module()
+                    self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
                    if self.fp8_meta["first_module"]:
                        # Wait for the prior AMAX reduction to finish
-                        amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
+                        amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
                        if amax_reduce_handle_fwd is not None:
                            amax_reduce_handle_fwd.wait()
-                        self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
-                        set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+                        self.fp8_meta["autocast_id_fwd"] = (
+                            FP8GlobalStateManager.new_fp8_context_id())
+                        FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                    else:
-                        self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
+                        self.fp8_meta["autocast_id_fwd"] = (
+                            FP8GlobalStateManager.get_fp8_context_id())
                    self.fp8_meta["autocast_id_fwd_stack"].append(
                        self.fp8_meta["autocast_id_fwd"]
                    )
-                    add_amax_to_global_buffer(self.fp8_meta, forward=True)
+                    FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
                    self.fp8_meta["update_amax_and_scale_fwd"] = True
                else:
                    self.fp8_meta["update_amax_and_scale_fwd"] = False
@@ -629,25 +622,25 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                and is_fp8_activation_recompute_enabled()
                and not in_fp8_activation_recompute_phase()
            ):
-                copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
+                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
-            restore_fp8_meta_tensors(self.fp8_meta)
+            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
-            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+            FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
-                global_amax_reduction,
+                FP8GlobalStateManager.global_amax_reduction,
                self.fp8_meta,
                self.tp_group,
                self.tp_size,
                forward=True
            )
-            setup_amax_forward_global_reduce_func(reduce_func)
+            FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
...