Unverified commit 2e0bfbd9, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] FP8 fixes (#380)



* Initial refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reorder methods by purpose
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save full global state
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes to test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cbfb8c6b
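For orientation, the thrust of this change is that the free functions previously exported from transformer_engine.pytorch.fp8 (is_fp8_available, is_fp8_enabled, get_fp8_recipe, and so on) are now reached as methods on the FP8GlobalStateManager class. A minimal sketch of the call-site migration, using only names that appear in the diff below (the helper check_fp8 is hypothetical):

```python
# Before this commit (old API, as removed in the hunks below):
#     from transformer_engine.pytorch.fp8 import is_fp8_available, is_fp8_enabled
#     fp8_available, reason = is_fp8_available()

# After this commit:
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

def check_fp8() -> None:
    """Hypothetical helper: report whether FP8 execution is possible and active."""
    fp8_available, reason = FP8GlobalStateManager.is_fp8_available()
    if not fp8_available:
        print(f"FP8 unavailable: {reason}")
    else:
        print("FP8 available; enabled in current context:",
              FP8GlobalStateManager.is_fp8_enabled())
```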
@@ -3,6 +3,7 @@ extension-pkg-whitelist=torch,
                         transformer_engine_extensions
 disable=too-many-locals,
+        too-many-public-methods,
         invalid-name,
         too-many-arguments,
         abstract-method,
......
@@ -10,7 +10,7 @@ from transformer_engine.pytorch.utils import (
     scaled_init_method_normal,
     get_device_compute_capability,
 )
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch import TransformerLayer
 from transformer_engine.pytorch.attention import DotProductAttention
 import os
@@ -18,7 +18,7 @@ import os
 from pkg_resources import packaging
 from importlib.metadata import version
 from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
......
@@ -38,7 +38,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp
 import transformer_engine.pytorch.softmax as softmax_defs
 from transformer_engine.pytorch.utils import get_default_init_method
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 # Global test configuration knobs.
@@ -66,7 +66,7 @@ assert OPSET >= TRILU_OPSET
 # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
 ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......
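As the hunk above shows, the availability check feeds directly into a pytest skip marker. A condensed sketch of that gating pattern, using the names from this diff (the test body is illustrative only):

```python
import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)

@skip_FP8
def test_fp8_smoke():
    # Hypothetical test; the real tests exercise export and attention paths.
    assert fp8_available
```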
@@ -5,7 +5,7 @@
 import torch
 import pytest
-from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -21,7 +21,7 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe
 # Only run FP8 tests on H100.
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 def custom_amax_to_scale(
......
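The test file above also imports fp8_autocast alongside FP8GlobalStateManager. For context, a minimal sketch of how an FP8 region is typically driven once availability is confirmed; te.Linear, recipe.DelayedScaling, and fp8_autocast are standard Transformer Engine APIs, and the tensor shapes are arbitrary:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
if fp8_available:
    layer = te.Linear(128, 128).cuda()
    inp = torch.randn(32, 128, device="cuda")
    # Inside the context, FP8GlobalStateManager.is_fp8_enabled() reports True
    # and the module's GEMMs run in FP8 according to the recipe.
    with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
        out = layer(inp)
else:
    print(f"Skipping FP8 demo: {reason_for_no_fp8}")
```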
@@ -12,7 +12,7 @@ from torch.utils.checkpoint import detach_variable
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import is_fp8_enabled
+from .fp8 import FP8GlobalStateManager
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
     "tensor_model_parallel": False,
@@ -145,7 +145,8 @@ def activation_recompute_forward(
     """
     global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
     try:
-        _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and is_fp8_enabled()
+        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
+            activation_recompute and FP8GlobalStateManager.is_fp8_enabled())
         _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
         yield
     finally:
......
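The hunk above follows a simple pattern: a context manager latches module-level flags from the global FP8 state for the duration of activation recompute, then clears them. A self-contained toy version of that pattern (the real code queries FP8GlobalStateManager.is_fp8_enabled() rather than taking a flag, and its finally block is not shown in this hunk, so the reset below is an assumption):

```python
from contextlib import contextmanager

# Module-level flags, mirroring the ones in the hunk above.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False

@contextmanager
def toy_activation_recompute_forward(activation_recompute, recompute_phase, fp8_enabled):
    """Toy stand-in for activation_recompute_forward."""
    global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
    try:
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and fp8_enabled
        _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
        yield
    finally:
        # Assumed reset: the flags should not leak past the recompute region.
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
        _FP8_ACTIVATION_RECOMPUTE_PHASE = False
```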
This diff is collapsed (the changes to this file are not shown here).
@@ -19,29 +19,10 @@ from torch.nn.parameter import Parameter
 import transformer_engine_extensions as tex
 from ..export import is_in_onnx_export_mode
 from ..fp8 import (
-    is_fp8_enabled,
-    is_fp8_calibration,
-    get_fp8_recipe,
-    get_fp8_group,
     get_default_fp8_recipe,
     get_fp8_te_dtype,
-    is_first_fp8_module,
-    new_fp8_context_id,
-    get_fp8_context_id,
-    set_fp8_context_id,
-    add_amax_to_global_buffer,
-    copy_amax_from_global_buffer,
-    global_amax_reduction,
-    setup_amax_forward_global_reduce_func,
+    FP8GlobalStateManager,
     amax_and_scale_update,
-    get_global_fp8_buffer,
-    set_global_fp8_buffer,
-    set_amax_buffer_key_deletion,
-    delete_key_from_amax_buffer,
-    copy_forward_fp8_meta_tensors_for_recompute,
-    get_old_fp8_meta_tensors_for_recompute,
-    restore_fp8_meta_tensors,
-    get_amax_reduce_handle_fwd,
 )
 from ..distributed import (
     gather_along_first_dim,
@@ -97,30 +78,30 @@ def _prepare_backward(
         # Update amax and scale; Skip all setup for global amax reduction
         if not fp8_meta["recipe"].reduce_amax:
-            amax_and_scale_update(fp8_meta, False)
+            FP8GlobalStateManager.amax_and_scale_update(fp8_meta, False)
         else:
             # From previous iteration
-            copy_amax_from_global_buffer(fp8_meta, forward=False)
+            FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
             amax_and_scale_update(fp8_meta, False)
-            set_amax_buffer_key_deletion(fp8_meta, forward=False)
+            FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
             # Get new backward key.
             fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
-            add_amax_to_global_buffer(fp8_meta, forward=False)
+            FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
     with torch.cuda.nvtx.range(name + " backward"):
         yield
     if fp8 and fp8_meta["recipe"].reduce_amax:
         if fp8_meta["first_module"]:
-            _amax_reduce_handle_bwd = global_amax_reduction(
+            _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
                 fp8_meta,
                 tp_group,
                 tp_size,
                 forward=False
             )
-            delete_key_from_amax_buffer(forward=False)
+            FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
 def initialize_ub(
@@ -356,12 +337,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
         state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
         state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-        state["global_fp8_buffer"] = get_global_fp8_buffer()
+        state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
+        state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
         # Store other pickelable values.
         extra = {}
         for k, v in self.fp8_meta.items():
-            if isinstance(v, (bool, int, float, str)):
+            if isinstance(v, (bool, int, float, str, list)):
                 extra[k] = v
         state["extra_fp8_variables"] = extra
@@ -403,7 +385,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
         # Restore global FP8 buffer state.
-        set_global_fp8_buffer(state[4])
+        FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
         self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
         self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
         self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
@@ -420,8 +402,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return
-        # Restore global FP8 buffer states.
-        set_global_fp8_buffer(state["global_fp8_buffer"])
+        # Restore global FP8 amax buffer.
+        FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
+        # Restore global FP8 state.
+        if "global_fp8_state" in state:
+            FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
+        else:
+            warnings.warn(
+                "This checkpoint format is deprecated and will be"
+                "removed in a future release of Transformer Engine"
+            )
         # Load extra items.
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
@@ -525,19 +515,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     # assume FP8 execution.
     def fp8_init(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        self.fp8 = is_fp8_enabled()
-        self.fp8_calibration = is_fp8_calibration()
+        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
         self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
         if self.fp8 or self.fp8_calibration:
             # FP8 init has already been run and recipe is the same, don't do anything.
-            if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
+            if (self.fp8_initialized
+                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]):
                 return
             # Set FP8, recipe, and other FP8 metadata
-            self.fp8_meta["recipe"] = get_fp8_recipe()
+            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
             self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = get_fp8_group()
+            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
             # Set FP8_MAX per tensor according to recipe
             self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
@@ -567,7 +558,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Activation recomputation is used and this is the second forward phase.
             if self.fp8 and in_fp8_activation_recompute_phase():
-                get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
+                FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
         else:
             assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -591,11 +582,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Previous iteration was grad_enabled
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
             if self.fp8_meta["recipe"].reduce_amax:
-                copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                 )
-                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+                FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
             else:
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
@@ -604,20 +595,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8 and self.training:
             # Setup for amax reduction
             if self.fp8_meta["recipe"].reduce_amax:
-                self.fp8_meta["first_module"] = is_first_fp8_module()
+                self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
                 if self.fp8_meta["first_module"]:
                     # Wait for the prior AMAX reduction to finish
-                    amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
+                    amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
                     if amax_reduce_handle_fwd is not None:
                         amax_reduce_handle_fwd.wait()
-                    self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
-                    set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+                    self.fp8_meta["autocast_id_fwd"] = (
+                        FP8GlobalStateManager.new_fp8_context_id())
+                    FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                 else:
-                    self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
+                    self.fp8_meta["autocast_id_fwd"] = (
+                        FP8GlobalStateManager.get_fp8_context_id())
                 self.fp8_meta["autocast_id_fwd_stack"].append(
                     self.fp8_meta["autocast_id_fwd"]
                 )
-                add_amax_to_global_buffer(self.fp8_meta, forward=True)
+                FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
             self.fp8_meta["update_amax_and_scale_fwd"] = True
         else:
             self.fp8_meta["update_amax_and_scale_fwd"] = False
@@ -629,25 +622,25 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             and is_fp8_activation_recompute_enabled()
             and not in_fp8_activation_recompute_phase()
         ):
-            copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
+            FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
         with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
             yield inp.contiguous()
         if self.fp8 and in_fp8_activation_recompute_phase():
-            restore_fp8_meta_tensors(self.fp8_meta)
+            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
             return
         if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
-            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
+            FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             reduce_func = partial(
-                global_amax_reduction,
+                FP8GlobalStateManager.global_amax_reduction,
                 self.fp8_meta,
                 self.tp_group,
                 self.tp_size,
                 forward=True
            )
-            setup_amax_forward_global_reduce_func(reduce_func)
+            FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
......
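The get_extra_state/set_extra_state hunks above add a second checkpoint key, "global_fp8_state", next to the existing "global_fp8_buffer", and keep loading older checkpoints that lack it while emitting a deprecation warning. A toy sketch of that save/restore shape (the class below is illustrative, not Transformer Engine's; it only mirrors the dictionary layout and fallback logic from the diff):

```python
import warnings
from typing import Optional

class ToyFP8ExtraState:
    """Illustrative only: mirrors the checkpoint save/restore shape in the diff above."""

    def __init__(self):
        self.global_fp8_buffer = {}  # per-module amax buffers
        self.global_fp8_state = {}   # other global FP8 bookkeeping (contents are TE-internal)

    def get_extra_state(self) -> dict:
        # Save both the amax buffer and, new in this commit, the full global state.
        return {
            "global_fp8_buffer": dict(self.global_fp8_buffer),
            "global_fp8_state": dict(self.global_fp8_state),
        }

    def set_extra_state(self, state: Optional[dict]) -> None:
        if state is None:
            return
        self.global_fp8_buffer = state["global_fp8_buffer"]
        if "global_fp8_state" in state:
            self.global_fp8_state = state["global_fp8_state"]
        else:
            # Older checkpoints predate "global_fp8_state": still load them, but warn,
            # matching the deprecation path added in this commit.
            warnings.warn(
                "This checkpoint format is deprecated and will be "
                "removed in a future release of Transformer Engine"
            )
```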