"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "c1a1c04e6341778562a45cff847dcaface8b33cc"
Unverified commit 05f6a691, authored by kwyss-nvidia and committed by GitHub

Update full recompute feature to save recipe. (#1577)



* Update full recompute feature to save recipe.

The recompute context uses the same recipe
and fp8 settings as the original fwd pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Formatted python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify code by relying on recipe in ctx
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback: import style
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c571c2fd
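
For context on the diff below, here is a minimal usage sketch (not part of this commit) of the situation it fixes: backward() re-runs the checkpointed forward pass outside the user's fp8_autocast region, so before this change the recompute did not see the original fp8 recipe. The sketch assumes the public transformer_engine.pytorch API; layer sizes are arbitrary.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Hypothetical layer sizes; any TE module behaves the same way here.
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16).cuda()
x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)
recipe = DelayedScaling(margin=0, amax_history_len=16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # Full recompute: activations are dropped in forward and rebuilt in backward.
    y = te.checkpoint(layer, x)

# backward() runs outside the fp8_autocast block; with this change the
# recompute re-enters fp8_autocast with the same `recipe` captured above.
y.sum().backward()
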
@@ -671,8 +671,6 @@ def test_gpt_full_activation_recompute(
         pytest.skip(reason_for_no_fp8)
     if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for full recompute.")
     config = model_configs[model]
...
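
Deleting the skip means float8 current-scaling recipes are now exercised by the full-recompute test. The parity check such a test performs, in a condensed hypothetical form (names, sizes, and structure assumed; this is not the actual test code):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

def run_once(block, x, recipe, use_recompute):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = te.checkpoint(block, x) if use_recompute else block(x)
    out.sum().backward()
    return out.detach(), x.grad.detach()

# Dropout disabled so both runs are deterministic and comparable; current
# scaling keeps no cross-run state, so the same block can be reused.
block = te.TransformerLayer(
    hidden_size=128, ffn_hidden_size=512, num_attention_heads=4,
    hidden_dropout=0.0, attention_dropout=0.0,
).cuda()
recipe = Float8CurrentScaling()
x_ref = torch.randn(32, 2, 128, device="cuda", requires_grad=True)
x_ckpt = x_ref.detach().clone().requires_grad_()

out_ref, grad_ref = run_once(block, x_ref, recipe, use_recompute=False)
out_ckpt, grad_ckpt = run_once(block, x_ckpt, recipe, use_recompute=True)

# With the recipe saved and replayed, the recomputed forward matches the
# baseline forward, so outputs and input gradients agree.
torch.testing.assert_close(out_ckpt, out_ref)
torch.testing.assert_close(grad_ckpt, grad_ref)
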
@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager
+from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
         tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
         ctx.save_for_backward(*tensor_inputs)
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
         ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
         ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
+        ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.kwargs = kwargs
         return outputs
@@ -375,6 +378,8 @@ class _CheckpointFunction(torch.autograd.Function):
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
+        ), fp8_autocast(
+            enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
         ):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
@@ -398,6 +403,9 @@ class _CheckpointFunction(torch.autograd.Function):
                 "none of output has requires_grad=True, this checkpoint() is not necessary"
             )
+        # backward does not require entering autocast context because
+        # backward implementations already retrieve fp8 recipe and
+        # enablement from stored ctx.
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
             inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
@@ -694,10 +702,15 @@ def checkpoint(
     # Preserve the torch autocast contexts from the forward pass during recompute phase.
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
+    fp8 = FP8GlobalStateManager.is_fp8_enabled()
+    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
+            enabled=fp8, fp8_recipe=fp8_recipe
+        ):
             function(*args, **kwargs)

     # Initialize a new checkpoint frame for each new forward pass.
...
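
The diff boils down to a capture-and-replay pattern: record the ambient fp8 state when the checkpoint is created, then re-enter it around the recompute inside backward. A stripped-down sketch of that pattern follows (simplified, not the actual _CheckpointFunction; make_ctx stands in for the saved fp8_autocast arguments, and tensor-only args are assumed):

import torch

class _RecomputeWithSavedContext(torch.autograd.Function):
    """Simplified stand-in for _CheckpointFunction: saves a context factory
    at forward time and replays it around the recompute in backward."""

    @staticmethod
    def forward(ctx, run_function, make_ctx, *args):
        ctx.run_function = run_function
        ctx.make_ctx = make_ctx  # e.g. captures fp8 enablement + recipe
        ctx.save_for_backward(*args)
        # Run without grad tracking; activations are intentionally not kept.
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        # Replay the saved context so the recomputed forward sees the same
        # settings as the original one, even though backward() itself runs
        # outside the user's autocast region.
        with torch.enable_grad(), ctx.make_ctx():
            outputs = ctx.run_function(*inputs)
        torch.autograd.backward(outputs, grad_outputs)
        return (None, None) + tuple(t.grad for t in inputs)

# Hypothetical usage: make_ctx closes over the recipe captured in forward.
#   make_ctx = lambda: te.fp8_autocast(enabled=True, fp8_recipe=recipe)
#   y = _RecomputeWithSavedContext.apply(layer, make_ctx, x)

As the in-diff comment notes, only the recomputed forward needs this context; the subsequent torch.autograd.backward call does not, because the TE module backward implementations read the fp8 recipe and enablement from their own stored ctx.
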