Unverified commit 8641ab77 authored by Tim Moon, committed by GitHub

[PyTorch] Do not allocate FP8 workspace buffers when params are FP8 (#647)



Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent b5e13a16
@@ -392,15 +392,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.activation_dtype = dtype
 
     def set_fp8_weights(self) -> None:
-        """Initializes FP8 weights for the module as class attributes. These
-        are not parameters or buffers since we do not want functions such as
-        `.to(dtype)` or `.to(device)` to effect them. These also do not need
-        to be checkpointed. During `init` phase of the module, the attribute
-        `fp8_weight_shapes` must be populated with the tensor shapes for FP8
-        weights. This function will iterate over those shapes and initialize
-        respective attributed named `weight1_fp8`, `weight2_fp8`, ...
+        """Construct workspace buffers for FP8 weights, if needed
+
+        These workspace buffers are used for FP8 training when the
+        module parameters are not natively in FP8 and there are
+        multiple microbatches per training step. The buffers, with
+        names like `weight1_fp8` and `weight1_t_fp8`, cache the FP8
+        values and transposed FP8 values in between microbatches. They
+        are not registered as module parameters or buffers since we
+        don't want them to be affected by `.to` and since they aren't
+        needed for checkpointing.
+
         """
-        if not self.fp8:
+        if not self.fp8 or self.primary_weights_in_fp8:
             return
         for i, shape in enumerate(self.fp8_weight_shapes, start=1):
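
For readers unfamiliar with the pattern, here is a minimal sketch of what the new docstring describes: lazily constructed plain attributes that persist across microbatches. This is an illustrative stand-in, not the Transformer Engine implementation; the `_Sketch` class is hypothetical and `torch.uint8` merely stands in for FP8 storage.

import torch

class _Sketch(torch.nn.Module):
    # Hypothetical stand-in for the relevant TransformerEngineBaseModule state.
    def __init__(self, fp8_weight_shapes, fp8=True, primary_weights_in_fp8=False):
        super().__init__()
        self.fp8 = fp8
        self.primary_weights_in_fp8 = primary_weights_in_fp8
        self.fp8_weight_shapes = list(fp8_weight_shapes)

    def set_fp8_weights(self) -> None:
        # Skip the workspace when FP8 is off, or when the parameters
        # are natively FP8 (the early return this commit adds).
        if not self.fp8 or self.primary_weights_in_fp8:
            return
        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
            # Plain attribute assignment: unlike nn.Parameter, a bare
            # tensor is not registered by nn.Module, so `.to(...)` and
            # `state_dict()` leave these caches alone.
            setattr(self, f"weight{i}_fp8",
                    torch.empty(shape, dtype=torch.uint8))
            setattr(self, f"weight{i}_t_fp8",
                    torch.empty(shape[::-1], dtype=torch.uint8))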
@@ -517,8 +521,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.init_fp8_metadata(num_gemms=num_gemms)
 
         # Create persistent tensors for fp8 weights and their transposes
-        # only when fp8 weight caching is used.
-        if is_first_microbatch is not None:
+        # only when fp8 weight caching is used and weights are not in fp8
+        if is_first_microbatch is not None and not self.primary_weights_in_fp8:
             self.set_fp8_weights()
 
         update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
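
As a usage note, this caching path is driven by the `is_first_microbatch` argument that Transformer Engine modules accept in their forward pass. Below is a hedged sketch of a gradient-accumulation loop that exercises it; the layer size, batch shape, and loop structure are illustrative assumptions.

import torch
import transformer_engine.pytorch as te

# Assumes an FP8-capable GPU (e.g. H100); sizes are illustrative.
model = te.Linear(768, 768)

for i in range(4):  # microbatches within one training step
    inp = torch.randn(32, 768, device="cuda")
    with te.fp8_autocast(enabled=True):
        # The first microbatch fills the FP8 weight workspace; later
        # ones reuse it. With FP8-native params the workspace is skipped.
        out = model(inp, is_first_microbatch=(i == 0))
    out.sum().backward()

The skipped case corresponds to parameters created natively in FP8, e.g. under `te.fp8_model_init()` in recent Transformer Engine releases, where a separate FP8 copy of the weights would be redundant.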