Unverified commit 02a3582c authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Move calculation of scale inverse to framework (#51)



* Move scale inverse calculation to framework
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix RMSNorm
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix gated kernel/geglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40467fc2
......@@ -171,6 +171,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
device="cuda",
)
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False]: -> [input, weight]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False] * self.fp8_meta["num_gemms"]
).cuda()
else:
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
# Checkpoint loaded
......@@ -360,7 +372,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.init_fp8_meta_tensors()
@contextmanager
def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
def prepare_forward(
self,
inp: torch.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1,
) -> None:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
......@@ -368,7 +385,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......@@ -382,14 +398,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(self.fp8_meta, True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(self.fp8_meta, True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
......@@ -1085,7 +1107,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
with self.prepare_forward(inp) as inp: # pylint
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
......@@ -1604,7 +1626,7 @@ class Linear(TransformerEngineBaseModule):
produced)
"""
with self.prepare_forward(inp) as inp:
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
......@@ -2400,7 +2422,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced)
"""
with self.prepare_forward(inp, num_gemms=2) as inp:
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
out = _LayerNormMLP.apply(
inp,
self.layer_norm_weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment