Unverified Commit 0c74535e authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Restore compatibility with Python 3.8 (#1189)



* Restore compatibility with Python 3.8
Signed-off-by: default avatarPrzemyslaw Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPrzemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 195d7032
...@@ -354,12 +354,8 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -354,12 +354,8 @@ class _CheckpointFunction(torch.autograd.Function):
# Compute the forward pass. # Compute the forward pass.
detached_inputs = detach_variable(inputs) detached_inputs = detach_variable(inputs)
with ( with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
torch.enable_grad(), activation_recompute=True, recompute_phase=True
ctx.recompute_ctx,
ctx.torch_gpu_amp_ctx,
ctx.torch_cpu_amp_ctx,
activation_recompute_forward(activation_recompute=True, recompute_phase=True),
): ):
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
...@@ -680,13 +676,9 @@ def checkpoint( ...@@ -680,13 +676,9 @@ def checkpoint(
torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
def recompute_fn(*args, **kwargs): def recompute_fn(*args, **kwargs):
with ( with torch.autograd.enable_grad(), (
torch.autograd.enable_grad(), te_recompute_ctx
te_recompute_ctx, ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
user_recompute_ctx,
torch_gpu_amp_forward_ctx,
torch_cpu_amp_forward_ctx,
):
function(*args, **kwargs) function(*args, **kwargs)
# Initialize a new checkpoint frame for each new forward pass. # Initialize a new checkpoint frame for each new forward pass.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment