Unverified Commit f96f3407 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] TE checkpoint pass-through logic fix (#782)



* changed TE checkpoint passthrough logic to also recursively look for TE submodules
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* simplified search for TE modules in the checkpointed network
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent d3552ddb
......@@ -498,15 +498,15 @@ def get_activation_recompute_contexts():
return forward_ctx, recompute_ctx
def _is_te_module(module):
def has_te_modules(network):
"""
Check if given module is a Transformer Engine module that requires the TE checkpoint
implementation for activation recompute.
Check if there are any Transformer Engine modules in the network.
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
LayerNorm,
RMSNorm,
......@@ -516,12 +516,13 @@ def _is_te_module(module):
MultiheadAttention,
TransformerLayer,
]
is_te_module = False
for te_class in te_classes_list:
if isinstance(module, te_class):
is_te_module = True
break
return is_te_module
if isinstance(network, torch.nn.Module):
for module in network.modules():
if any(isinstance(module, te_class) for te_class in te_classes_list):
return True
return False
def checkpoint(
......@@ -584,14 +585,12 @@ def checkpoint(
distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
args = args[3:]
# Trigger the native PyTorch checkpoint if:
# 1. `function` is a `torch.nn.Module`
# AND
# 2. `function` is NOT a TE module
# Trigger the native PyTorch checkpoint if the function is not or does not contain a
# Transformer Engine module.
context_fn = kwargs.pop("context_fn", noop_context_fn)
determinism_check = kwargs.pop("determinism_check", "default")
debug = kwargs.pop("debug", False)
if isinstance(function, torch.nn.Module) and not _is_te_module(function):
if not has_te_modules(function):
return torch.utils.checkpoint.checkpoint(
function,
*args,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment