Unverified Commit aee78831 authored by Paweł Gadziński, committed by GitHub

[PyTorch] Fix for checkpointing for callables. (#1679)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* added test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* changed the test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 313ab4f4
@@ -42,6 +42,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols

# Only run FP8 tests on supported devices.
@@ -1285,3 +1286,31 @@ def test_fp8_model_init_high_precision_init_val():
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() did not work"
def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable (non-module) functions."""

    # Plain callable built from a torch.autograd.Function, not an nn.Module.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    grad_checkpoint = inp.grad.clone()

    # Reset the accumulated gradient before the non-checkpointed run.
    inp.grad = None
    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
@@ -661,6 +661,9 @@ def checkpoint(
            **kwargs,
        )
    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
        setattr(function, "fsdp_wrapped", False)
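
For context, the sketch below (not part of the diff; that the setattr was previously applied to every `function` is an assumption inferred from the newly added guard) shows why the isinstance check matters: `torch.autograd.Function.apply` resolves to a bound method, and bound methods reject new attributes, so an unconditional setattr raises AttributeError for such callables.

import torch

class Identity(torch.autograd.Function):
    """Minimal autograd function whose .apply is a plain callable."""

    @staticmethod
    def forward(ctx, inp):
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

fn = Identity.apply  # not a TransformerEngineBaseModule instance

try:
    # Hypothetical pre-fix behavior: setattr applied unconditionally to `function`.
    setattr(fn, "fsdp_wrapped", False)
except AttributeError as err:
    print("setattr on a bound-method callable fails:", err)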