"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "a1cc8c55195661a58ad60c3bb062a0b9c302710d"
Unverified commit 7fc079a4, authored by schetlur-nv, committed by GitHub

Schetlur/fp8 calibration (#40)



* Initial commit for fp8 calibration.
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixes to make unit tests pass
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Added test and finished implementation
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Cleaning up handling of save_for_backward in Linear
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Removing commented lines
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Minor fix to mnist test.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Moving stats computation to the forward pass instead of pre_forward, and extending to all other layers
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Pylint cleanup.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Fixing unit test failures.
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Misc changes
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>

* Fixing bad indentation from master merge and moving some code into the needs_stats conditional
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: schetlur-nv <116769508+schetlur-nv@users.noreply.github.com>
Co-authored-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
parent 275902fd
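In short, this change adds a calibration mode to Transformer Engine: forward passes can run in higher precision while still recording the per-tensor amax statistics that FP8 execution needs, so a model trained without FP8 can later be served in FP8. A minimal sketch of the resulting workflow (the helper name, loader, and checkpoint path below are illustrative, not part of this commit):

import torch
import transformer_engine.pytorch as te

def calibrate_then_infer(model, calib_loader, sample):
    """Hypothetical helper: calibrate a TE-based model, then run FP8 inference."""
    model.eval()
    with torch.no_grad():
        # Pass 1: higher-precision forwards that only record amax statistics.
        for data, _ in calib_loader:
            with te.fp8_autocast(enabled=False, calibrating=True):
                model(data.cuda())
        # The recorded scaling state travels with the checkpoint (see get_extra_state below).
        torch.save(model.state_dict(), "calibrated.pt")
        model.load_state_dict(torch.load("calibrated.pt"))
        # Pass 2: FP8 inference using the calibrated scales.
        with te.fp8_autocast(enabled=True):
            return model(sample.cuda())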
@@ -68,6 +68,17 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
                 break


+def calibrate(model, device, test_loader):
+    """Calibration function."""
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            with te.fp8_autocast(enabled=False, calibrating=True):
+                output = model(data)
+
+
 def test(model, device, test_loader, use_fp8):
     """Testing function."""
     model.eval()
@@ -156,7 +167,10 @@ def main():
         help="For Saving the current Model",
     )
     parser.add_argument(
-        "--use-fp8", action="store_true", default=False, help="Use FP8 training"
+        "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
+    )
+    parser.add_argument(
+        "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
     )
     parser.add_argument(
         "--use-te", action="store_true", default=False, help="Use Transformer Engine"
@@ -164,10 +178,13 @@ def main():
     args = parser.parse_args()
     use_cuda = torch.cuda.is_available()

-    if args.use_fp8:
+    if args.use_fp8 or args.use_fp8_infer:
         assert use_cuda, "CUDA needed for FP8 execution."
         args.use_te = True
+    if args.use_fp8_infer:
+        assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat checkpoint"

     torch.manual_seed(args.seed)
     device = torch.device("cuda" if use_cuda else "cpu")
@@ -196,8 +213,15 @@ def main():
         test(model, device, test_loader, args.use_fp8)
         scheduler.step()

-    if args.save_model:
+    if args.use_fp8_infer:
+        calibrate(model, device, test_loader)
+
+    if args.save_model or args.use_fp8_infer:
         torch.save(model.state_dict(), "mnist_cnn.pt")
+        print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
+        weights = torch.load("mnist_cnn.pt")
+        model.load_state_dict(weights)
+        test(model, device, test_loader, args.use_fp8_infer)


 if __name__ == "__main__":
...
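Assuming the example above is the MNIST script touched by this commit (its path is not shown in this excerpt), the new flag would be exercised roughly as follows; --use-fp8-infer implies --use-te, runs training and evaluation in higher precision, calibrates, saves mnist_cnn.pt, then reloads the checkpoint and tests with FP8:

# Illustrative invocations; the script name is assumed to be main.py
python main.py --use-fp8-infer   # calibrate, save, reload, evaluate with FP8
python main.py --use-fp8         # FP8 for training and inference, no recalibration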
@@ -14,6 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format
 from .constants import dist_group_type

 _FP8_ENABLED = False
+_FP8_CALIBRATION = False
 _FP8_RECIPE = None
 _FP8_DISTRIBUTED_GROUP = None
 _IS_FIRST_FP8_MODULE = False
@@ -201,6 +202,7 @@ def get_default_fp8_recipe() -> DelayedScaling:
 @contextmanager
 def fp8_autocast(
     enabled: bool = False,
+    calibrating: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
     fp8_group: Optional[dist_group_type] = None,
 ) -> None:
@@ -229,12 +231,13 @@ def fp8_autocast(
     are reduced at the end of each training step.
     """
-    global _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
+    global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
     global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
     global _global_fp8_buffer, _buffer_delete_key_fwd

-    fp8_state = (_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
+    fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
     try:
         _FP8_ENABLED = enabled
+        _FP8_CALIBRATION = calibrating
         _FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
         _FP8_DISTRIBUTED_GROUP = fp8_group
@@ -249,7 +252,7 @@ def fp8_autocast(
         ), "Device compute capability 9.x required for FP8 execution."
         yield
     finally:
-        _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
+        _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
         _IS_FIRST_FP8_MODULE = False
         _FP8_AUTOCAST_DEPTH -= 1
@@ -281,6 +284,9 @@ def is_fp8_enabled() -> bool:
     """Is FP8 enabled"""
     return _FP8_ENABLED


+def is_fp8_calibration() -> bool:
+    """Is FP8 calibration"""
+    return _FP8_CALIBRATION


 def is_first_fp8_module():
     """Returns `True` only the first time when called multiple
...
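The new flag is plain module-level state managed by the context manager: fp8_autocast snapshots it in fp8_state and restores it in the finally block, so calibration can be toggled per forward pass. A small sketch of the observable behavior, assuming the usual transformer_engine.pytorch import paths:

import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import is_fp8_enabled, is_fp8_calibration

with te.fp8_autocast(enabled=False, calibrating=True):
    # Inside the context: FP8 compute stays off, but modules collect amax stats.
    assert not is_fp8_enabled()
    assert is_fp8_calibration()

# On exit the previous state is restored.
assert not is_fp8_calibration()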
@@ -19,6 +19,7 @@ from torch.nn import init
 import transformer_engine_extensions as tex
 from .fp8 import (
     is_fp8_enabled,
+    is_fp8_calibration,
     get_fp8_recipe,
     get_fp8_group,
     get_default_fp8_recipe,
@@ -142,7 +143,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def __init__(self) -> None:
         super().__init__()
         assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
+        self.fp8_initialized = False
         self.fp8 = False
+        self.fp8_calibration = False
         self.fp8_meta = {}
         self.fp8_meta["fp8_group"] = None
         self.fp8_meta["recipe"] = get_default_fp8_recipe()
@@ -199,7 +202,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
         state = None
-        if self.fp8:
+        if self.fp8 or self.fp8_calibration:
             state = {}
             state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
             state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
@@ -359,29 +362,33 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.tp_group = tp_group
         self.tp_group_initialized = True

+    # This routine is shared across FP8 and FP8_calibration paths so should not actually
+    # assume FP8 execution.
     def fp8_init(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        # If fp8 isn't enabled, turn off and return.
-        if not is_fp8_enabled():
-            self.fp8 = False
-            return
-
-        # FP8 is already enabled and recipe is the same, don't do anything.
-        if self.fp8 and get_fp8_recipe() == self.fp8_meta["recipe"]:
-            return
-
-        # Set FP8, recipe, and other FP8 metadata
-        self.fp8 = True
-        self.fp8_meta["recipe"] = get_fp8_recipe()
-        self.fp8_meta["num_gemms"] = num_gemms
-        self.fp8_meta["fp8_group"] = get_fp8_group()
-
-        # Set FP8_MAX per tensor according to recipe
-        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
-
-        # Allocate scales and amaxes
-        self.init_fp8_meta_tensors()
+        if is_fp8_enabled() or is_fp8_calibration():
+            # FP8 init has already been run and recipe is the same, don't do anything.
+            if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
+                return
+
+            # Set FP8, recipe, and other FP8 metadata
+            self.fp8 = is_fp8_enabled()
+            self.fp8_calibration = is_fp8_calibration()
+            self.fp8_meta["recipe"] = get_fp8_recipe()
+            self.fp8_meta["num_gemms"] = num_gemms
+            self.fp8_meta["fp8_group"] = get_fp8_group()
+
+            # Set FP8_MAX per tensor according to recipe
+            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
+            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

+            # Allocate scales and amaxes
+            self.init_fp8_meta_tensors()
+            self.fp8_initialized = True
+        else:
+            # If fp8 isn't enabled, turn off and return.
+            self.fp8_initialized = False
+            return

     @contextmanager
     def prepare_forward(
@@ -410,22 +417,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_init(num_gemms=num_gemms)
         self.set_fp8_weights()

-        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
-
-        # Previous iteration was grad_enabled
-        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            if self.fp8_meta["recipe"].reduce_amax:
-                copy_amax_from_global_buffer(self.fp8_meta, forward=True)
-                amax_and_scale_update(
-                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
-                )
-                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
-            else:
-                amax_and_scale_update(
-                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
-                )
-
-        if self.fp8 and self.training:
+        # Either we're in FP8 training or calibration for FP8 inference
+        needs_stats = (self.training if self.fp8 else self.fp8_calibration)
+
+        if needs_stats:
+            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
+            # Previous iteration was grad_enabled
+            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
+                if self.fp8_meta["recipe"].reduce_amax:
+                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+                    amax_and_scale_update(
+                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    )
+                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+                else:
+                    amax_and_scale_update(
+                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    )
+
             # Setup for amax reduction
             if self.fp8_meta["recipe"].reduce_amax:
                 self.fp8_meta["first_module"] = is_first_fp8_module()
@@ -438,17 +446,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     self.fp8_meta["autocast_id_fwd"]
                 )
                 add_amax_to_global_buffer(self.fp8_meta, forward=True)
             self.fp8_meta["update_amax_and_scale_fwd"] = True
         else:
             self.fp8_meta["update_amax_and_scale_fwd"] = False

         # Activation recomputation is used and this is the first forward phase.
         if (
             self.fp8
             and is_fp8_activation_recompute_enabled()
             and not in_fp8_activation_recompute_phase()
         ):
             copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

         with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
             yield inp.contiguous()
@@ -589,6 +597,7 @@ class _LayerNormLinear(torch.autograd.Function):
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
+        fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
@@ -721,6 +730,14 @@ class _LayerNormLinear(torch.autograd.Function):
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

+            if fp8_calibration:
+                # amax of input
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
+                    torch.amax(ln_out_total).float()
+                # amax of weight
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
+                    torch.amax(weight).float()
+
             out, _, _ = gemm(
                 weight,
                 ln_out_total,
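In calibration mode the GEMM in the hunk above still runs in the activation dtype; the only side effect of the new branch is writing the current amax of the GEMM input and weight into slot 0 of amax_history. The existing amax_and_scale_update path later turns those recorded amaxes into FP8 scales. A simplified, self-contained sketch of that relationship (margin and amax-history reduction are omitted, and the names here are illustrative, not taken from the commit):

import torch

tensor = torch.randn(16, 16)             # stands in for ln_out_total or weight
fp8_max = 448.0                          # e.g. the E4M3 maximum, i.e. fp8_meta["fp8_max_fwd"]
amax = torch.amax(tensor.abs()).float()  # the calibration branch records torch.amax(tensor)
scale = fp8_max / amax if amax > 0 else torch.tensor(1.0)
scale_inv = 1.0 / scale                  # used to dequantize FP8 GEMM operands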
@@ -1195,6 +1212,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.eps,
             is_first_microbatch,
             self.fp8,
+            self.fp8_calibration,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
@@ -1239,6 +1257,7 @@ class _Linear(torch.autograd.Function):
         use_bias: bool,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
+        fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
@@ -1338,6 +1357,14 @@ class _Linear(torch.autograd.Function):
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

+            if fp8_calibration:
+                # amax of input
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
+                    torch.amax(inputmat_total).float()
+                # amax of weight
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
+                    torch.amax(weight).float()
+
             out, _, _ = gemm(
                 weight,
                 inputmat_total,
@@ -1348,15 +1375,16 @@ class _Linear(torch.autograd.Function):
             )

         if is_training:
+            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
             ctx.save_for_backward(
-                inputmat_no_fp8
-                if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad
+                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad
                 else None,
-                inputmat_t
-                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
+                inputmat_t if weight.requires_grad and fp8_wgrad
                 else None,
-                weight,
-                weight_t_fp8,
+                weight if inputmat.requires_grad and not fp8
+                else None,
+                weight_t_fp8 if inputmat.requires_grad and fp8
+                else None,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
             )
             ctx.activation_dtype = activation_dtype
@@ -1370,6 +1398,8 @@ class _Linear(torch.autograd.Function):
             ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
+            ctx.requires_dgrad = inputmat.requires_grad
+            ctx.requires_wgrad = weight.requires_grad

         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -1434,30 +1464,32 @@ class _Linear(torch.autograd.Function):
                     ctx.fp8_meta["recipe"], fprop_tensor=False
                 )

+            if ctx.requires_dgrad:
                 # DGRAD
-                dgrad = fp8_gemm(
-                    weight_t_fp8,
-                    fwd_scale_inverses,
-                    tex.FP8FwdTensors.GEMM1_WEIGHT,
-                    fp8_dtype_forward,
-                    grad_output_c,
-                    ctx.fp8_meta["scaling_bwd"].scale_inv,
-                    tex.FP8BwdTensors.GRAD_OUTPUT1,
-                    fp8_dtype_backward,
-                    ctx.activation_dtype,
-                    get_workspace(),
-                    use_split_accumulator=_2X_ACC_DGRAD,
-                )
-            else:
-                # DGRAD
-                dgrad, _, _ = gemm(
-                    weight,
-                    grad_output,
-                    ctx.activation_dtype,
-                    get_workspace(),
-                    layout="NN",
-                    grad=True,
-                )
+                if ctx.fp8:
+                    dgrad = fp8_gemm(
+                        weight_t_fp8,
+                        fwd_scale_inverses,
+                        tex.FP8FwdTensors.GEMM1_WEIGHT,
+                        fp8_dtype_forward,
+                        grad_output_c,
+                        ctx.fp8_meta["scaling_bwd"].scale_inv,
+                        tex.FP8BwdTensors.GRAD_OUTPUT1,
+                        fp8_dtype_backward,
+                        ctx.activation_dtype,
+                        get_workspace(),
+                        use_split_accumulator=_2X_ACC_DGRAD,
+                    )
+                else:
+                    # DGRAD
+                    dgrad, _, _ = gemm(
+                        weight,
+                        grad_output,
+                        ctx.activation_dtype,
+                        get_workspace(),
+                        layout="NN",
+                        grad=True,
+                    )

         # Overlap dgrad-RS/AR with wgrad
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
@@ -1468,7 +1500,7 @@ class _Linear(torch.autograd.Function):
         elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
             dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

-        if weight.requires_grad:
+        if ctx.requires_wgrad:
             if ctx.fp8:
                 # WGRAD
                 if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -1523,10 +1555,10 @@ class _Linear(torch.autograd.Function):
             grad_bias = None

         return (
-            wgrad if weight.requires_grad else None,
+            wgrad if ctx.requires_wgrad else None,
             None,
             None,
-            dgrad.view(ctx.inp_shape),
+            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             grad_bias,
             None,
             None,
@@ -1539,6 +1571,7 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1745,6 +1778,7 @@ class Linear(TransformerEngineBaseModule):
             self.use_bias,
             is_first_microbatch,
             self.fp8,
+            self.fp8_calibration,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
@@ -1787,6 +1821,7 @@ class _LayerNormMLP(torch.autograd.Function):
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
+        fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
@@ -1957,6 +1992,14 @@ class _LayerNormMLP(torch.autograd.Function):
                 cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
             )

+            if fp8_calibration:
+                # amax of fc1 input
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
+                    torch.amax(ln_out_total).float()
+                # amax of fc1 weight
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
+                    torch.amax(fc1_weight).float()
+
             fc1_outputs = gemm(
                 fc1_weight,
                 ln_out_total,
@@ -1973,6 +2016,14 @@ class _LayerNormMLP(torch.autograd.Function):
             else:
                 gelu_out, _, fc1_out = fc1_outputs

+            if fp8_calibration:
+                # amax of fc2 input
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
+                    torch.amax(gelu_out).float()
+                # amax of fc2 weight
+                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
+                    torch.amax(fc2_weight).float()
+
             fc2_out, _, _ = gemm(
                 fc2_weight,
                 gelu_out,
@@ -2612,6 +2663,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.eps,
             is_first_microbatch,
             self.fp8,
+            self.fp8_calibration,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             self.tp_group,
@@ -2761,6 +2813,7 @@ class LayerNorm(torch.nn.Module):
         init.ones_(self.weight)
         init.zeros_(self.bias)
+
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
         # Maintain backward compatibility.
...