Unverified Commit 28aa41a3 authored by Tim Moon, committed by GitHub

[PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure (#1326)



* Remove manual FP8 scale update for FP8 params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c0a539c6
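The net effect at module-registration call sites, condensed from the hunks below (`self` stands for any Transformer Engine module; a sketch of the before/after, not a complete listing):

    # Old call site (removed): FP8 params were registered separately so that
    # a post-optimizer-step callback could trigger an extra forward amax
    # reduction for weights.
    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
        self.fp8_meta, fp8_weights=self._get_fp8_params()
    )

    # New call site: weights and activations share one registration path.
    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)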
@@ -293,7 +293,7 @@ class TestFuser:
             )

             # Check that scaling factors match expected
-            w_amax_ref = max(w_vals[: step + 2])
+            w_amax_ref = max(w_vals[: step + 1])
             x_amax_ref = max(x_vals[: step + 1])
             dy_amax_ref = max(dy_vals[: step + 1])
             w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
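The reference window for the weight amax shrinks from `step + 2` to `step + 1`: previously the weight scale was refreshed again after the optimizer step, so at step N it already reflected the next step's weight amax; with that path removed, weight scales update on the same per-step schedule as activations and gradients. A minimal illustration with hypothetical values:

    w_vals = [0.5, 2.0, 1.0, 4.0]
    step = 1
    w_amax_ref = max(w_vals[: step + 1])  # 2.0 -- same window as x/dy now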
...
@@ -109,8 +109,6 @@ class FP8GlobalStateManager:
         cls.fp8_available = None
         cls.reason_for_no_fp8 = ""
         cls.autocast_arguments = {}
-        cls.autocast_to_fp8_params = {}
-        cls.fp8_param_to_autocast = {}
         cls.skip_fp8_weight_update_tensor = None

     @classmethod
@@ -156,28 +154,25 @@ class FP8GlobalStateManager:
     def get_key_in_buffer(
         cls,
         forward: bool,
-        fp8_weights: bool,
         fp8_recipe: DelayedScaling,
         fp8_group: dist_group_type,
     ) -> str:
         """Returns a key into the global FP8 buffers."""
         autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
         fwd_bwd_key = cls.get_fwd_bwd_key(forward)
-        return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
+        return f"{fwd_bwd_key}_{autocast_key}"

     @classmethod
-    def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
+    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
         """Splits buffer key into relevant parts."""
-        forward, fp8_weights, autocast_key = key.split("_", 2)
+        forward, autocast_key = key.split("_", 1)
         forward = forward == "forward"
-        fp8_weights = fp8_weights == "True"
-        return forward, fp8_weights, autocast_key
+        return forward, autocast_key

     @classmethod
     def add_fp8_tensors_to_global_buffer(
         cls,
         fp8_meta: Dict[str, Any],
-        fp8_weights: Optional[List[torch.Tensor]] = None,
     ) -> None:
         """
         The amax reduction process happens completely outside the FP8 modules.
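In isolation, the new key format drops the middle `fp8_weights` field, so a key round-trips with a single split. A minimal runnable sketch (standalone functions that mirror the diff, not the library code; `"recipeA_group0"` is a made-up autocast key):

    def get_key_in_buffer(forward: bool, autocast_key: str) -> str:
        # New format: "<forward|backward>_<autocast key>"; the fp8_weights
        # flag that used to sit in the middle is gone.
        fwd_bwd_key = "forward" if forward else "backward"
        return f"{fwd_bwd_key}_{autocast_key}"

    def split_key_in_buffer(key: str):
        # Inverse: one split recovers the direction and the autocast key.
        forward, autocast_key = key.split("_", 1)
        return forward == "forward", autocast_key

    assert split_key_in_buffer(get_key_in_buffer(True, "recipeA_group0")) == (
        True,
        "recipeA_group0",
    )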
@@ -202,33 +197,12 @@ class FP8GlobalStateManager:
         fp8_meta[index_in_buffer] = []

         for forward in (True, False):
-            # This algorithm creates a two-way map with `autocast_to_fp8_params` and
-            # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
-            # in an autocasted region and cross reference them in `float8_tensor.py`
-            # to perform the forward amax reduction.
             fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
             if fp8_meta_tensor_key not in fp8_meta:
                 # Handles non-parameter FP8 modules, e.g. DPA.
                 continue
-            if forward and fp8_weights is not None:
-                autocast_key = cls.get_unique_autocast_key(
-                    fp8_meta["recipe"], fp8_meta["fp8_group"]
-                )
-                fp8_weight_set = {id(w._data) for w in fp8_weights}
-                if autocast_key not in cls.autocast_to_fp8_params:
-                    cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
-                else:
-                    cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[
-                        autocast_key
-                    ].union(fp8_weight_set)
-                # Identify correct autocast key for a given param.
-                for w in fp8_weight_set:
-                    cls.fp8_param_to_autocast[w] = autocast_key
-            key = cls.get_key_in_buffer(
-                forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]
-            )
+            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

             if key not in cls.global_amax_buffer:
                 cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
@@ -327,20 +301,13 @@ class FP8GlobalStateManager:
     def reduce_and_update_fp8_tensors(
         cls,
         forward: bool = True,
-        fp8_weights: bool = False,
     ) -> None:
         """Concatenate, reduce, and split amaxes in the global buffer."""
         for buffer_key, amax_buffer in cls.global_amax_buffer.items():
             # Check for forward or backward reduction.
-            fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
+            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
             if fwd_update != forward:
                 continue
-            # Only skip a forward update when `fp8_weights` is explicitly set to `True`
-            # (inside optimizer) and the current key is not an `fp8_weight_update` key.
-            # For other cases, we need to reduce because of activation tensors.
-            # TODO(ksivaman) consider separate weight and activation fp8_tensors.
-            if fwd_update and fp8_weights and not fp8_weights_update:
-                continue
             if len(amax_buffer) == 0:
                 continue
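With the weight-specific skip path gone, the loop reduces every buffered amax history whose direction matches the requested one. A runnable toy version of the selection logic (illustrative names; the distributed all-reduce is deliberately omitted):

    import torch

    # "<forward|backward>_<autocast key>" -> list of per-module amax histories
    global_amax_buffer = {
        "forward_recipeA": [torch.zeros(4), torch.zeros(4)],
        "backward_recipeA": [torch.zeros(4)],
    }

    def reduce_amaxes(forward: bool = True) -> None:
        for buffer_key, amax_buffer in global_amax_buffer.items():
            fwd_update, autocast_key = buffer_key.split("_", 1)
            if (fwd_update == "forward") != forward or len(amax_buffer) == 0:
                continue
            contiguous = torch.cat(amax_buffer)
            # A real implementation would all-reduce `contiguous` over the
            # process group identified by `autocast_key`, then split the
            # result back into the individual histories.

    reduce_amaxes(forward=True)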
@@ -434,7 +401,7 @@ class FP8GlobalStateManager:
         # FP8 weight modules are reduced at the end of the optimizer
         # step after the weight amax is populated.
         if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
-            cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
+            cls.reduce_and_update_fp8_tensors(forward=True)

     @classmethod
     def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
...
@@ -465,7 +465,7 @@ def _make_graphed_callables(
                         m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                         m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                         FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                            m.fp8_meta, fp8_weights=m._get_fp8_params()
+                            m.fp8_meta,
                         )
                 return graphed(*user_args, **user_kwargs)
             return orig_fwd(*user_args, **user_kwargs)
...
@@ -762,9 +762,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             )

         if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
-            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                self.fp8_meta, fp8_weights=self._get_fp8_params()
-            )
+            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

         # Activation recomputation is used and this is the first forward phase.
         if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
...
@@ -19,7 +19,7 @@ from transformer_engine.pytorch.fp8 import (
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
-from ._common import canonicalize_device, is_float8_tensor
+from ._common import canonicalize_device


 @dataclasses.dataclass
@@ -379,10 +379,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 self.get_fp8_meta("input"),
             )
         if self.num_fp8_scales("param"):
-            fp8_params = list(filter(is_float8_tensor, self.parameters()))
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                 self.get_fp8_meta("param"),
-                fp8_weights=(fp8_params if fp8_params else None),
             )
         if self.num_fp8_scales("grad_output"):
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
...
@@ -74,30 +74,6 @@ class _FromFloat8Func(torch.autograd.Function):
         return grad, None


-def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
-    """Amax scale and update when there is at least 1 trainable FP8 parameter."""
-    param_id = id(param._data)
-    if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
-        return
-    autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
-    if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
-        return
-    if autocast_key in updated_fp8_params:
-        updated_fp8_params[autocast_key].add(param_id)
-    else:
-        updated_fp8_params[autocast_key] = {param_id}
-    current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
-    # All FP8 trainable parameters have been updated.
-    if updated_fp8_params[autocast_key] == current_fp8_params_set:
-        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
-        del updated_fp8_params[autocast_key]
-
-
 class _ToFloat8Func(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
@@ -676,9 +652,6 @@ class Float8Tensor(QuantizedTensor):
             )
             dst._transpose_invalid = False

-        # Callback hook to perform amax reduction after optimizer step
-        post_optimizer_step_fwd_amax_reduction(self)
-
         return self

     @classmethod
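User-visible behavior after this commit: copying into a `Float8Tensor` during `optimizer.step()` no longer fires a per-parameter amax-reduction callback; the forward reduction is handled by the autocast infrastructure instead (see the `reduce_and_update_fp8_tensors(forward=True)` call above). A hedged usage sketch, assuming a CUDA device with FP8 support:

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    inp = torch.randn(16, 16, device="cuda")

    with te.fp8_autocast(enabled=True):
        out = model(inp)   # amax histories land in the shared global buffer
    out.sum().backward()
    optimizer.step()       # no Float8Tensor copy hook triggers a reduction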
...