Unverified commit 8641ab77, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Do not allocate FP8 workspace buffers when params are FP8 (#647)



Do not allocate FP8 workspace buffers when params are FP8
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent b5e13a16
......@@ -392,15 +392,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.activation_dtype = dtype
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module as class attributes. These
are not parameters or buffers since we do not want functions such as
`.to(dtype)` or `.to(device)` to affect them. These also do not need
to be checkpointed. During `init` phase of the module, the attribute
`fp8_weight_shapes` must be populated with the tensor shapes for FP8
weights. This function will iterate over those shapes and initialize
respective attributes named `weight1_fp8`, `weight2_fp8`, ...
"""Construct workspace buffers for FP8 weights, if needed
These workspace buffers are used for FP8 training when the
module parameters are not natively in FP8 and there are
multiple microbatches per training step. The buffers, with
names like `weight1_fp8` and `weight1_t_fp8`, cache the FP8
values and transposed FP8 values in between microbatches. They
are not registered as module parameters or buffers since we
don't want them to be affected by `.to` and since they aren't
needed for checkpointing.
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
......@@ -517,8 +521,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.init_fp8_metadata(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
if is_first_microbatch is not None:
# only when fp8 weight caching is used and weights are not in fp8
if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment