"examples/jax/vscode:/vscode.git/clone" did not exist on "aaf9354861c447eb5d53cb2907572490c0673861"
Unverified Commit f96f3407 authored by Alp Dener, committed by GitHub

[PyTorch] TE checkpoint pass-through logic fix (#782)



* changed TE checkpoint passthrough logic to also recursively look for TE submodules
Signed-off-by: Alp Dener <adener@nvidia.com>

* simplified search for TE modules in the checkpointed network
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent d3552ddb
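
The case the fix targets: a plain torch.nn.Module that merely wraps TE layers is not itself an instance of any TE class, so the old _is_te_module(function) check routed it to the native torch.utils.checkpoint.checkpoint and skipped TE's activation-recompute handling. A minimal sketch of that situation, assuming transformer_engine.pytorch is installed and a CUDA device is available (the Wrapper class below is hypothetical, not part of TE):

# Sketch of the scenario the fix covers (hypothetical Wrapper class; assumes
# transformer_engine.pytorch is installed and a CUDA device is available).
import torch
import transformer_engine.pytorch as te

class Wrapper(torch.nn.Module):
    """Plain PyTorch module that only *contains* a TE layer."""
    def __init__(self):
        super().__init__()
        self.proj = te.Linear(128, 128)  # TE submodule nested inside a non-TE parent

    def forward(self, x):
        return self.proj(x)

net = Wrapper().cuda()
x = torch.randn(8, 128, device="cuda")

# The wrapper itself is not a TE class ...
print(isinstance(net, (te.Linear, te.LayerNorm, te.TransformerLayer)))       # False
# ... but a recursive walk over net.modules() does find the nested TE layer,
# which is what the new has_te_modules() check relies on.
print(any(isinstance(m, te.Linear) for m in net.modules()))                  # True

# With the fix, te.checkpoint(net, x) therefore takes the TE recompute path
# instead of silently falling back to torch.utils.checkpoint.checkpoint.
out = te.checkpoint(net, x)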
@@ -498,15 +498,15 @@ def get_activation_recompute_contexts():
     return forward_ctx, recompute_ctx


-def _is_te_module(module):
+def has_te_modules(network):
     """
-    Check if given module is a Transformer Engine module that requires the TE checkpoint
-    implementation for activation recompute.
+    Check if there are any Transformer Engine modules in the network.
     """
     from .module import LayerNorm, RMSNorm
     from .module.base import TransformerEngineBaseModule
     from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
     from .transformer import TransformerLayer

     te_classes_list = [
         LayerNorm,
         RMSNorm,
@@ -516,12 +516,13 @@ def _is_te_module(module):
         MultiheadAttention,
         TransformerLayer,
     ]

-    is_te_module = False
-    for te_class in te_classes_list:
-        if isinstance(module, te_class):
-            is_te_module = True
-            break
-    return is_te_module
+    if isinstance(network, torch.nn.Module):
+        for module in network.modules():
+            if any(isinstance(module, te_class) for te_class in te_classes_list):
+                return True
+
+    return False


 def checkpoint(
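
Note on why a single loop suffices: torch.nn.Module.modules() yields the module itself first and then every nested submodule, so "is a TE module" and "contains TE modules" collapse into the same check. A plain-PyTorch illustration:

# torch.nn.Module.modules() yields the root module followed by all descendants,
# so the same loop handles both a top-level TE module and a nested one.
import torch

inner = torch.nn.Linear(4, 4)
outer = torch.nn.Sequential(torch.nn.Sequential(inner))

print([type(m).__name__ for m in outer.modules()])
# ['Sequential', 'Sequential', 'Linear']

print(any(isinstance(m, torch.nn.Linear) for m in outer.modules()))  # True (nested)
print(any(isinstance(m, torch.nn.Linear) for m in inner.modules()))  # True (the module itself)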
@@ -584,14 +585,12 @@ def checkpoint(
         distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]  # pylint: disable=unbalanced-tuple-unpacking
         args = args[3:]

-    # Trigger the native PyTorch checkpoint if:
-    # 1. `function` is a `torch.nn.Module`
-    # AND
-    # 2. `function` is NOT a TE module
+    # Trigger the native PyTorch checkpoint if the function is not or does not contain a
+    # Transformer Engine module.
     context_fn = kwargs.pop("context_fn", noop_context_fn)
     determinism_check = kwargs.pop("determinism_check", "default")
     debug = kwargs.pop("debug", False)
-    if isinstance(function, torch.nn.Module) and not _is_te_module(function):
+    if not has_te_modules(function):
         return torch.utils.checkpoint.checkpoint(
             function,
             *args,
...
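
For completeness, the other side of the new condition: a callable that neither is nor contains a TE module is now handed directly to the native PyTorch checkpoint. The old condition also required function to be a torch.nn.Module, so a plain Python callable always took the TE path; with has_te_modules, any callable without TE submodules falls through to the native implementation. A hedged usage sketch, assuming transformer_engine.pytorch is importable:

# Pass-through sketch: no TE modules anywhere in the callable, so te.checkpoint
# behaves like torch.utils.checkpoint.checkpoint. Assumes transformer_engine.pytorch
# is importable.
import torch
import transformer_engine.pytorch as te

plain = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
x = torch.randn(8, 64, requires_grad=True)

# has_te_modules(plain) is False -> equivalent to
# torch.utils.checkpoint.checkpoint(plain, x).
y = te.checkpoint(plain, x)
y.sum().backward()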