Unverified Commit 02a3582c authored by Kirthi Shankar Sivamani, committed by GitHub

Move calculation of scale inverse to framework (#51)



* Move scale inverse calculation to framework
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix RMSNorm
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix gated kernel/geglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40467fc2
@@ -171,6 +171,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             device="cuda",
         )
 
+        # Needed for calculation of scale inverses to
+        # preserve scale_inv when caching FP8 weights
+        if fwd:
+            # [True, False]: -> [input, weight]
+            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
+                [True, False] * self.fp8_meta["num_gemms"]
+            ).cuda()
+        else:
+            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
+                [True] * self.fp8_meta["num_gemms"]
+            ).cuda()
+
     def init_fp8_meta_tensors(self) -> None:
         """Init scales and amaxes."""
         # Checkpoint loaded
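The mask marks which FP8 tensors in the forward meta are not weights (per GEMM, the forward layout is [input, weight]), so a later scale-inverse refresh can skip the weight entries while cached FP8 weights are being reused. A minimal illustration of that masking, with torch.where standing in for the framework-side update (an assumption for illustration, not code from this diff):

    import torch

    num_gemms = 2
    # Forward meta layout per GEMM: [input, weight] -> True marks non-weight entries.
    non_weight_mask = torch.BoolTensor([True, False] * num_gemms).cuda()

    scale = torch.full((2 * num_gemms,), 2.0, device="cuda")           # freshly updated scales
    old_scale_inv = torch.full((2 * num_gemms,), 0.25, device="cuda")  # matches cached FP8 weights

    # Refresh scale_inv only where the mask is True; weight entries keep the
    # value that still corresponds to the cached FP8 weight data.
    new_scale_inv = torch.where(non_weight_mask, 1.0 / scale, old_scale_inv)
    # new_scale_inv -> [0.5, 0.25, 0.5, 0.25]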
@@ -360,7 +372,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.init_fp8_meta_tensors()
 
     @contextmanager
-    def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
+    def prepare_forward(
+        self,
+        inp: torch.Tensor,
+        is_first_microbatch: Union[bool, None],
+        num_gemms: int = 1,
+    ) -> None:
         """Checks and prep for FWD.
         The context manager is needed because there isn't a way for a module to know
         if it's the last FP8 module in the forward autocast. It is useful
@@ -368,7 +385,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         just in case. The autocast exit will pick up the most recent one.
         """
-
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
             get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
@@ -382,14 +398,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             self.fp8_init(num_gemms=num_gemms)
             self.set_fp8_weights()
 
+            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
+
             # Previous iteration was grad_enabled
             if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                 if self.fp8_meta["recipe"].reduce_amax:
                     copy_amax_from_global_buffer(self.fp8_meta, forward=True)
-                    amax_and_scale_update(self.fp8_meta, True)
+                    amax_and_scale_update(
+                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    )
                     set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
                 else:
-                    amax_and_scale_update(self.fp8_meta, True)
+                    amax_and_scale_update(
+                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    )
 
         if self.fp8 and self.training:
             # Setup for amax reduction
@@ -1085,7 +1107,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                    produced)
         """
 
-        with self.prepare_forward(inp) as inp:  # pylint
+        with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = bias if bias is not None else self.bias
 
             out = _LayerNormLinear.apply(
@@ -1604,7 +1626,7 @@ class Linear(TransformerEngineBaseModule):
                    produced)
         """
 
-        with self.prepare_forward(inp) as inp:
+        with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = bias if bias is not None else self.bias
 
             out = _Linear.apply(
@@ -2400,7 +2422,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    produced)
         """
 
-        with self.prepare_forward(inp, num_gemms=2) as inp:
+        with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
             out = _LayerNormMLP.apply(
                 inp,
                 self.layer_norm_weight,
......
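At these call sites, is_first_microbatch is simply threaded through to prepare_forward. A hedged usage sketch of how a gradient-accumulation loop might drive it; the constructor arguments, shapes, and recipe settings here are illustrative rather than taken from this diff:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # With gradient accumulation, is_first_microbatch=True casts and caches the FP8
    # weights; later microbatches reuse them together with the preserved weight scale_inv.
    fp8_recipe = recipe.DelayedScaling()
    layer = te.Linear(1024, 1024).cuda()

    microbatches = [torch.randn(16, 1024, device="cuda") for _ in range(4)]
    for i, mb in enumerate(microbatches):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            out = layer(mb, is_first_microbatch=(i == 0))
        out.sum().backward()  # gradients accumulate across microbatches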