Unverified Commit 80542a0a authored by Kirthi Shankar Sivamani, committed by GitHub

Change FP8 recipe defaults (#112)



* Change FP8 recipe defaults
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Increase default amax history length
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Always check history size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* No amax history for ONNX export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert ONNX export test changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix indices in ONNX test
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
parent bcbd4be0
@@ -105,12 +105,15 @@ def to_numpy(tensor):
     return tensor.cpu().numpy()
 
 
-def set_layer_scale(module: torch.nn.Module, scale: float):
-    module.fp8_init()
+def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
+    """Initialize the FP8 quantization scales in module"""
+    NB_SCALES_PER_GEMM = 3  # One scale per: input, weights, and output GEMM tensors.
+    nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
+    module.fp8_init(num_gemms)
     module.fp8_meta["scaling_fwd"].scale = torch.ones(
-        2, dtype=torch.float32, device="cuda") / scale
+        nb_total_scales, dtype=torch.float32, device="cuda") / scale
     module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
-        2, dtype=torch.float32, device="cuda") * scale
+        nb_total_scales, dtype=torch.float32, device="cuda") * scale
 
 
 def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
@@ -678,7 +681,7 @@ def test_export_linear(
         precision
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model.linear, scale_factor)
+        set_layer_scale(model.linear, scale_factor, num_gemms=1)
     do_export(model, inp, fname, use_fp8)
 
     if precision in (torch.bfloat16, ):
@@ -736,7 +739,7 @@ def test_export_layernorm_linear(
         zero_centered_gamma=zero_centered_gamma,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor)
+        set_layer_scale(model, scale_factor, num_gemms=1)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
@@ -792,7 +795,7 @@ def test_export_layernorm_mlp(
         zero_centered_gamma=zero_centered_gamma,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor)
+        set_layer_scale(model, scale_factor, num_gemms=2)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
@@ -66,10 +66,10 @@ class DelayedScaling:
     fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
                 Controls the FP8 data format used during forward and backward
                 pass.
-    amax_history_len : int, default = 1
+    amax_history_len : int, default = 1024
                       The length of the amax history window used for
                       scaling factor computation.
-    amax_compute_algo : {'max', 'most_recent', Callable}, default = 'most_recent'
+    amax_compute_algo : {'max', 'most_recent', Callable}, default = 'max'
                        Algorithm used for choosing the `amax` value for the
                        scaling factor computation. There are 2 predefined
                        choices: `max` chooses the largest `amax` in the history
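A rough standalone illustration of the two predefined choices; the assumption that the newest amax sits at index 0 of the history window is mine, not stated in this diff:

import torch

# amax history for a single tensor over a 4-step window, newest entry first (assumed)
amax_history = torch.tensor([0.5, 2.0, 1.0, 0.25])

amax_if_max = amax_history.max()       # 'max': largest amax anywhere in the window -> 2.0
amax_if_most_recent = amax_history[0]  # 'most_recent': only the newest amax        -> 0.5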
@@ -125,8 +125,8 @@ class DelayedScaling:
     margin: int = 0
     interval: int = 1
     fp8_format: Format = Format.HYBRID
-    amax_history_len: int = 1
-    amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent"
+    amax_history_len: int = 1024
+    amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
     override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
     scaling_factor_compute_algo: Optional[Callable] = None
     reduce_amax: bool = True
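A minimal usage sketch of the new defaults; the layer sizes and input shapes below are illustrative assumptions, and after this change a bare DelayedScaling() yields the same recipe:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Spelled out explicitly for clarity; amax_history_len=1024 and
# amax_compute_algo="max" are now the defaults.
fp8_recipe = recipe.DelayedScaling(
    margin=0,
    interval=1,
    fp8_format=recipe.Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)

model = te.Linear(768, 768).to(device="cuda")
inp = torch.randn(32, 768, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)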
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 from torch.nn import init
@@ -187,6 +188,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def set_meta_tensor(self, fwd: bool) -> None:
         """Init scales and amaxes for fwd | bwd."""
         fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
+
+        if self.fp8_meta_tensors_initialized:
+            # Handle changed amax history size.
+            curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
+            need_len = self.fp8_meta["recipe"].amax_history_len
+            if need_len < curr_len:
+                self.fp8_meta[fp8_meta_tensor_key].amax_history = (
+                    self.fp8_meta[fp8_meta_tensor_key]
+                    .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
+                )
+            elif need_len > curr_len:
+                extra_rows = need_len - curr_len
+                self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
+                    self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
+                )
+            return
+
         # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
         # 2 (grad_output and grad_input) for bwd
         num_fp8_tensors = (
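A self-contained sketch of the resize behavior added above; the helper name and tensor shapes are illustrative assumptions, not part of the module's API:

import torch
import torch.nn.functional as F

def resize_amax_history(amax_history: torch.Tensor, need_len: int) -> torch.Tensor:
    """Truncate or zero-pad a (history_len, num_fp8_tensors) buffer to need_len rows."""
    curr_len = amax_history.shape[0]
    if need_len < curr_len:
        return amax_history[:need_len].clone()
    if need_len > curr_len:
        # pad=(0, 0, 0, extra_rows) leaves the last dim alone and appends rows of zeros.
        return F.pad(amax_history, pad=(0, 0, 0, need_len - curr_len))
    return amax_history

history = torch.zeros(1, 6)                   # old default: amax_history_len=1, 6 fwd tensors
history = resize_amax_history(history, 1024)  # new default length
assert history.shape == (1024, 6)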
@@ -222,12 +240,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
 
     def init_fp8_meta_tensors(self) -> None:
         """Init scales and amaxes."""
-        # Checkpoint loaded
-        if self.fp8_meta_tensors_initialized:
-            return
         self.set_meta_tensor(True)
         self.set_meta_tensor(False)
         self.fp8_meta_tensors_initialized = True
 
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
@@ -280,7 +295,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
         self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
         self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
-        self.fp8_meta_tensors_initialized = True
 
         # Restore global FP8 buffer state.
         set_global_fp8_buffer(state[4])
@@ -310,7 +324,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
         self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
         self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-        self.fp8_meta_tensors_initialized = True
 
     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""