Unverified Commit aee78831 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix for checkpointing for callables. (#1679)



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* added test
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test change
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* changed the test
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 313ab4f4
......@@ -42,6 +42,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
# Only run FP8 tests on supported devices.
......@@ -1285,3 +1286,31 @@ def test_fp8_model_init_high_precision_init_val():
assert not hasattr(
weight, "._high_precision_init_val"
), "clear_high_precision_init_val() not work"
def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable (non-nn.Module) objects.

    Runs a trivial identity ``torch.autograd.Function`` through TE's
    ``checkpoint`` and through a plain forward pass, then checks that the
    input gradients from both paths match.
    """

    # A bare torch.autograd.Function — exercised via its .apply callable,
    # which is not a TransformerEngineBaseModule instance.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Clone the gradient: inp.grad is accumulated in-place by later
    # backward() calls, so holding a bare reference here would alias the
    # accumulated tensor and make the final comparison trivially true.
    grad_checkpoint = inp.grad.detach().clone()
    # Reset the gradient so the reference pass does not accumulate on top
    # of the checkpointed pass's gradient.
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
......@@ -661,10 +661,13 @@ def checkpoint(
**kwargs,
)
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
from .module.base import TransformerEngineBaseModule
if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment