Unverified Commit 7fc079a4 authored by schetlur-nv, committed by GitHub

Schetlur/fp8 calibration (#40)



* Initial commit for fp8 calibration.
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixes to make unit tests pass
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Added test and finished implementation
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Cleaning up handling of save_for_backward in Linear
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Removing commented lines
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Minor fix to mnist test.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Moving stats computation to the forward pass instead of pre_forward, and extending to all other layers
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Fixing unit test failures.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Misc changes
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Fixing bad indentation from master merge and moving some code into the needs_stats conditional
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: schetlur-nv <116769508+schetlur-nv@users.noreply.github.com>
Co-authored-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
parent 275902fd
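
In short, this change threads a new `calibrating` flag through `te.fp8_autocast` so that Transformer Engine modules record amax statistics for their GEMM inputs and weights while the math still runs in higher precision; the calibrated scales are saved with the checkpoint (via `get_extra_state`) and reused for FP8 inference. A minimal sketch of that flow, assuming Transformer Engine is installed and an FP8-capable GPU is available; the toy `te.Linear` model and synthetic data are placeholders, not part of this diff:

import torch
import transformer_engine.pytorch as te

# A toy TE model and synthetic data, only to make the sketch self-contained.
model = te.Linear(16, 16).cuda()
calib_data = [torch.randn(8, 16, device="cuda") for _ in range(4)]

# Pass 1: run in higher precision while recording per-tensor amax statistics.
model.eval()
with torch.no_grad():
    for batch in calib_data:
        with te.fp8_autocast(enabled=False, calibrating=True):
            model(batch)

# The calibrated amaxes/scales travel with the checkpoint via get_extra_state().
torch.save(model.state_dict(), "calibrated.pt")

# Pass 2: reload and run inference with FP8 enabled, without recalibrating.
model.load_state_dict(torch.load("calibrated.pt"))
with torch.no_grad(), te.fp8_autocast(enabled=True):
    out = model(calib_data[0])
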
@@ -68,6 +68,17 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
break
def calibrate(model, device, test_loader):
"""Calibration function."""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=False, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
@@ -156,7 +167,10 @@ def main():
help="For Saving the current Model",
)
parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 training"
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
@@ -164,10 +178,13 @@ def main():
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
if args.use_fp8:
if args.use_fp8 or args.use_fp8_infer:
assert use_cuda, "CUDA needed for FP8 execution."
args.use_te = True
if args.use_fp8_infer:
        assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat16 checkpoint"
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
@@ -196,8 +213,15 @@ def main():
test(model, device, test_loader, args.use_fp8)
scheduler.step()
if args.save_model:
if args.use_fp8_infer:
calibrate(model, device, test_loader)
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
if __name__ == "__main__":
......
@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None
_FP8_DISTRIBUTED_GROUP = None
_IS_FIRST_FP8_MODULE = False
@@ -201,6 +202,7 @@ def get_default_fp8_recipe() -> DelayedScaling:
@contextmanager
def fp8_autocast(
enabled: bool = False,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
) -> None:
@@ -229,12 +231,13 @@
are reduced at the end of each training step.
"""
global _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
fp8_state = (_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
try:
_FP8_ENABLED = enabled
_FP8_CALIBRATION = calibrating
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
_FP8_DISTRIBUTED_GROUP = fp8_group
@@ -249,7 +252,7 @@
), "Device compute capability 9.x required for FP8 execution."
yield
finally:
_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
        _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_DEPTH -= 1
@@ -281,6 +284,9 @@ def is_fp8_enabled() -> bool:
"""Is FP8 enabled"""
return _FP8_ENABLED
def is_fp8_calibration() -> bool:
"""Is FP8 calibration"""
return _FP8_CALIBRATION
def is_first_fp8_module():
"""Returns `True` only the first time when called multiple
......
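
Note that `fp8_autocast` snapshots the module-level flags on entry and restores them in the `finally` block, so the new calibration flag nests and unwinds together with the enable flag. A standalone sketch of that save/restore pattern (illustrative globals and names only, not the library's actual ones):

from contextlib import contextmanager

_ENABLED = False
_CALIBRATING = False

@contextmanager
def autocast_like(enabled=False, calibrating=False):
    """Illustrative save/restore pattern mirroring fp8_autocast."""
    global _ENABLED, _CALIBRATING
    saved = (_ENABLED, _CALIBRATING)        # snapshot current state
    try:
        _ENABLED, _CALIBRATING = enabled, calibrating
        yield
    finally:
        _ENABLED, _CALIBRATING = saved      # restore on exit, even after errors

# Nesting behaves as expected: leaving the inner context restores the outer settings.
with autocast_like(calibrating=True):
    with autocast_like(enabled=True):
        pass
    assert (_ENABLED, _CALIBRATING) == (False, True)
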
@@ -19,6 +19,7 @@ from torch.nn import init
import transformer_engine_extensions as tex
from .fp8 import (
is_fp8_enabled,
is_fp8_calibration,
get_fp8_recipe,
get_fp8_group,
get_default_fp8_recipe,
@@ -142,7 +143,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
self.fp8_meta = {}
self.fp8_meta["fp8_group"] = None
self.fp8_meta["recipe"] = get_default_fp8_recipe()
@@ -199,7 +202,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
if self.fp8:
if self.fp8 or self.fp8_calibration:
state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
@@ -359,19 +362,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = tp_group
self.tp_group_initialized = True
    # This routine is shared across the FP8 and FP8-calibration paths, so it should not
    # assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
# If fp8 isn't enabled, turn off and return.
if not is_fp8_enabled():
self.fp8 = False
return
# FP8 is already enabled and recipe is the same, don't do anything.
if self.fp8 and get_fp8_recipe() == self.fp8_meta["recipe"]:
if is_fp8_enabled() or is_fp8_calibration():
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8 = True
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
@@ -382,6 +384,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
self.fp8_initialized = True
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
@contextmanager
def prepare_forward(
@@ -410,8 +417,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Either we're in FP8 training or calibration for FP8 inference
needs_stats = (self.training if self.fp8 else self.fp8_calibration)
if needs_stats:
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
@@ -424,8 +434,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
@@ -589,6 +597,7 @@ class _LayerNormLinear(torch.autograd.Function):
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
@@ -721,6 +730,14 @@ class _LayerNormLinear(torch.autograd.Function):
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
weight,
ln_out_total,
@@ -1195,6 +1212,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
@@ -1239,6 +1257,7 @@ class _Linear(torch.autograd.Function):
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
@@ -1338,6 +1357,14 @@ class _Linear(torch.autograd.Function):
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(inputmat_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
weight,
inputmat_total,
@@ -1348,15 +1375,16 @@ class _Linear(torch.autograd.Function):
)
if is_training:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
ctx.save_for_backward(
inputmat_no_fp8
if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad
else None,
inputmat_t
if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
inputmat_t if weight.requires_grad and fp8_wgrad
else None,
weight if inputmat.requires_grad and not fp8
else None,
weight_t_fp8 if inputmat.requires_grad and fp8
else None,
weight,
weight_t_fp8,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
@@ -1370,6 +1398,8 @@ class _Linear(torch.autograd.Function):
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.requires_dgrad = inputmat.requires_grad
ctx.requires_wgrad = weight.requires_grad
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
@@ -1434,7 +1464,9 @@ class _Linear(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.requires_dgrad:
# DGRAD
if ctx.fp8:
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
@@ -1468,7 +1500,7 @@ class _Linear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.requires_wgrad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -1523,10 +1555,10 @@ class _Linear(torch.autograd.Function):
grad_bias = None
return (
wgrad if weight.requires_grad else None,
wgrad if ctx.requires_wgrad else None,
None,
None,
dgrad.view(ctx.inp_shape),
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
None,
None,
@@ -1539,6 +1571,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -1745,6 +1778,7 @@ class Linear(TransformerEngineBaseModule):
self.use_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
@@ -1787,6 +1821,7 @@ class _LayerNormMLP(torch.autograd.Function):
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
@@ -1957,6 +1992,14 @@ class _LayerNormMLP(torch.autograd.Function):
cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
)
if fp8_calibration:
# amax of fc1 input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
# amax of fc1 weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(fc1_weight).float()
fc1_outputs = gemm(
fc1_weight,
ln_out_total,
@@ -1973,6 +2016,14 @@ class _LayerNormMLP(torch.autograd.Function):
else:
gelu_out, _, fc1_out = fc1_outputs
if fp8_calibration:
# amax of fc2 input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
torch.amax(gelu_out).float()
# amax of fc2 weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.amax(fc2_weight).float()
fc2_out, _, _ = gemm(
fc2_weight,
gelu_out,
@@ -2612,6 +2663,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
@@ -2761,6 +2813,7 @@ class LayerNorm(torch.nn.Module):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
......
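
After a calibration pass, each Transformer Engine module that processed data holds the recorded amaxes in `fp8_meta["scaling_fwd"].amax_history`; a small illustrative loop for inspecting them (assuming `model` is a calibrated network of TE modules, as in the sketch near the top of this page):

for name, module in model.named_modules():
    # fp8_calibration is set by fp8_init() when calibrating=True was active.
    if getattr(module, "fp8_calibration", False):
        amaxes = module.fp8_meta["scaling_fwd"].amax_history[0]
        print(f"{name}: forward amax_history[0] = {amaxes.tolist()}")
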