Unverified commit 7c4887b2 authored by Alp Dener, committed by GitHub

[PyTorch] Support `torch.amp.autocast` in TE checkpoint (#791)



TE checkpoint now preserves the `torch.amp.autocast` context from the forward pass during the recompute phase.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 82e5b4d2
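As a user-facing sketch of the behavior this commit enables (module choice, shapes, and input names are illustrative, not from the diff): activations recomputed during `backward()` now see the same `torch.amp.autocast` state that was active in the original forward pass.

```python
import torch
import transformer_engine.pytorch as te

# Illustrative layer and input; any checkpointable module works.
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16).cuda()
inp = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = te.checkpoint(layer, inp)

# The recompute inside backward() now re-enters the bf16 autocast region
# instead of silently running in fp32.
out.sum().backward()
```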
@@ -228,6 +228,26 @@ def in_fp8_activation_recompute_phase() -> bool:
     return _FP8_ACTIVATION_RECOMPUTE_PHASE
 
 
+def _get_active_autocast_contexts():
+    """
+    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+    at the time of this function's execution.
+    """
+    autocast_cached = torch.is_autocast_cache_enabled()
+
+    gpu_autocast_enabled = torch.is_autocast_enabled()
+    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+    gpu_autocast_ctx = torch.cuda.amp.autocast(
+        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached)
+
+    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+    cpu_autocast_ctx = torch.cpu.amp.autocast(
+        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached)
+
+    return gpu_autocast_ctx, cpu_autocast_ctx
+
+
 class _CheckpointFunction(torch.autograd.Function):
     """This function is adapted from torch.utils.checkpoint with
     two main changes:
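A quick illustration of the helper's capture-and-replay semantics (not part of the diff; assumes a CUDA build of PyTorch): the contexts returned by `_get_active_autocast_contexts()` can be entered later, outside the original autocast region, and reproduce the state that was active when they were captured.

```python
import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Capture the autocast state that is active right now.
    gpu_ctx, cpu_ctx = _get_active_autocast_contexts()

# Later, e.g. in the recompute phase, outside the original autocast region:
with gpu_ctx, cpu_ctx:
    assert torch.is_autocast_enabled()
    assert torch.get_autocast_gpu_dtype() == torch.bfloat16
```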
@@ -262,6 +282,10 @@ class _CheckpointFunction(torch.autograd.Function):
             forward_ctx, recompute_ctx = context_fn()
         else:
             forward_ctx, recompute_ctx = noop_context_fn()
 
+        # Preserve torch autocast context for the backward pass
+        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()
+
         with torch.no_grad(), forward_ctx:
             with activation_recompute_forward(
                 activation_recompute=True, recompute_phase=False
@@ -287,6 +311,8 @@ class _CheckpointFunction(torch.autograd.Function):
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
+        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
+        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
         ctx.kwargs = kwargs
 
         return outputs
@@ -331,11 +357,11 @@ class _CheckpointFunction(torch.autograd.Function):
         # Compute the forward pass.
         detached_inputs = detach_variable(inputs)
-        with torch.enable_grad(), ctx.recompute_ctx:
-            with activation_recompute_forward(
-                activation_recompute=True, recompute_phase=True
-            ):
-                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
+        with (torch.enable_grad(), ctx.recompute_ctx,
+              ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx,
+              activation_recompute_forward(
+                  activation_recompute=True, recompute_phase=True)):
+            outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
 
         # Set the states back to what they were at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
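Why the recompute must re-enter autocast, sketched with plain PyTorch (illustrative names, not TE code): outside the captured contexts, recomputed activations would come out in fp32 and mismatch the dtypes produced, and saved, during the autocast forward pass.

```python
import torch

lin = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y_fwd = lin(x)        # runs in fp16 under autocast

y_re = lin(x)             # "recompute" outside autocast runs in fp32
assert y_fwd.dtype == torch.float16
assert y_re.dtype == torch.float32
```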
@@ -639,8 +665,13 @@ def checkpoint(
     user_forward_ctx, user_recompute_ctx = context_fn()
     te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()
 
+    # Preserve the torch autocast contexts from the forward pass during recompute phase.
+    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
+
     def recompute_fn(*args, **kwargs):
-        with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx:
+        with (torch.autograd.enable_grad(),
+              te_recompute_ctx, user_recompute_ctx,
+              torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx):
             function(*args, **kwargs)
 
     # Initialize a new checkpoint frame for each new forward pass.
@@ -650,7 +681,8 @@ def checkpoint(
     )
     new_frame.cache_rng_states(forward=True)
 
-    with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
+    with (_checkpoint_hook(new_frame, args, kwargs),
+          te_forward_ctx, user_forward_ctx):
         out = function(*args, **kwargs)
 
     return out
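The last two hunks cover the non-reentrant checkpoint path (`recompute_fn` wired up through `_checkpoint_hook`). Assuming that path is selected by a `use_reentrant=False` argument mirroring `torch.utils.checkpoint` (the flag itself is not shown in this diff), usage looks the same:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # use_reentrant=False is an assumption mirroring torch.utils.checkpoint;
    # it selects the hook-based path whose recompute_fn re-enters autocast.
    out = te.checkpoint(layer, inp, use_reentrant=False)

out.sum().backward()
```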