Unverified Commit 0c74535e authored by Przemyslaw Tredak, committed by GitHub
Browse files

Restore compatibility with Python 3.8 (#1189)



* Restore compatibility with Python 3.8
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 195d7032
......@@ -354,12 +354,8 @@ class _CheckpointFunction(torch.autograd.Function):
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with (
torch.enable_grad(),
ctx.recompute_ctx,
ctx.torch_gpu_amp_ctx,
ctx.torch_cpu_amp_ctx,
activation_recompute_forward(activation_recompute=True, recompute_phase=True),
with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
activation_recompute=True, recompute_phase=True
):
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
......@@ -680,13 +676,9 @@ def checkpoint(
torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
def recompute_fn(*args, **kwargs):
with (
torch.autograd.enable_grad(),
te_recompute_ctx,
user_recompute_ctx,
torch_gpu_amp_forward_ctx,
torch_cpu_amp_forward_ctx,
):
with torch.autograd.enable_grad(), (
te_recompute_ctx
), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
function(*args, **kwargs)
# Initialize a new checkpoint frame for each new forward pass.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment