"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "fe8fad59c9add7f8aa841fed2a0b4087b931856f"
Unverified commit 05f6a691 authored by kwyss-nvidia, committed by GitHub

Update full recompute feature to save recipe. (#1577)



* Update full recompute feature to save recipe.

The recompute context uses the same recipe
and fp8 settings as the original fwd pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Formatted python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify code by relying on recipe in ctx
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback: import style
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c571c2fd
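For context, a minimal usage sketch (not part of this commit) of the path being fixed: under `fp8_autocast`, `te.checkpoint` drops activations in the forward pass and recomputes them during backward, and the recompute must see the same FP8 settings. The layer shape and `DelayedScaling` recipe below are illustrative assumptions.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024)         # illustrative module
fp8_recipe = recipe.DelayedScaling()  # illustrative recipe choice
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # Forward pass runs in FP8; activations are discarded for recompute.
    out = te.checkpoint(layer, inp)

# Backward triggers the recompute phase, which (with this change) re-enters
# fp8_autocast using the recipe captured at forward time.
out.sum().backward()
```

Previously the recompute ran under whatever FP8 autocast state happened to be active during backward, which could differ from the forward pass.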
@@ -671,8 +671,6 @@ def test_gpt_full_activation_recompute(
         pytest.skip(reason_for_no_fp8)
     if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for full recompute.")
 
     config = model_configs[model]
@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager
+from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
         tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
         ctx.save_for_backward(*tensor_inputs)
 
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
         ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
         ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
+        ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.kwargs = kwargs
 
         return outputs
@@ -375,6 +378,8 @@
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
+        ), fp8_autocast(
+            enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
         ):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
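The two hunks above follow a general re-entrancy pattern: `forward` stashes the ambient state it will need on `ctx`, and `backward` rebuilds an equivalent context around the recompute. A hedged, self-contained sketch of that pattern, where `run_fn` and `make_ctx` are hypothetical stand-ins rather than Transformer Engine APIs:

```python
import torch

class _RecomputeWithState(torch.autograd.Function):
    """Recompute run_fn in backward under a context captured at forward time."""

    @staticmethod
    def forward(ctx, run_fn, make_ctx, *args):
        ctx.run_fn = run_fn
        ctx.make_ctx = make_ctx  # zero-arg callable returning a context manager
        ctx.save_for_backward(*args)  # assumes all extra args are tensors
        with torch.no_grad():
            return run_fn(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [
            t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors
        ]
        # Re-enter the captured context before re-running the function, the way
        # _CheckpointFunction.backward re-enters fp8_autocast with ctx.fp8_recipe.
        with torch.enable_grad(), ctx.make_ctx():
            outputs = ctx.run_fn(*inputs)
        torch.autograd.backward(outputs, grad_outputs)
        # No grads flow to run_fn or make_ctx themselves.
        return (None, None) + tuple(t.grad for t in inputs)
```

With `make_ctx=lambda: fp8_autocast(enabled=fp8, fp8_recipe=saved_recipe)`, this reduces to the recipe-preserving behavior added here.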
@@ -398,6 +403,9 @@
                 "none of output has requires_grad=True, this checkpoint() is not necessary"
             )
 
+        # backward does not require entering autocast context because
+        # backward implementations already retrieve fp8 recipe and
+        # enablement from stored ctx.
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
             inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
@@ -694,10 +702,15 @@ def checkpoint(
     # Preserve the torch autocast contexts from the forward pass during recompute phase.
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
 
+    fp8 = FP8GlobalStateManager.is_fp8_enabled()
+    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+
     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
+            enabled=fp8, fp8_recipe=fp8_recipe
+        ):
             function(*args, **kwargs)
 
     # Initialize a new checkpoint frame for each new forward pass.
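The `checkpoint` hunk above relies on closure capture: `fp8` and `fp8_recipe` are read once, when the forward pass builds `recompute_fn`, so the deferred recompute is immune to later changes in global FP8 state. A small standalone sketch of the idea, with hypothetical names (`_GLOBAL_STATE`, `state_ctx`, `make_recompute_fn`):

```python
import contextlib

_GLOBAL_STATE = {"enabled": False, "recipe": None}

def make_recompute_fn(function):
    # Snapshot taken when the checkpoint frame is created (forward time).
    enabled = _GLOBAL_STATE["enabled"]
    saved_recipe = _GLOBAL_STATE["recipe"] if enabled else None

    @contextlib.contextmanager
    def state_ctx():
        # Temporarily restore the snapshot around the recompute.
        previous = dict(_GLOBAL_STATE)
        _GLOBAL_STATE.update(enabled=enabled, recipe=saved_recipe)
        try:
            yield
        finally:
            _GLOBAL_STATE.update(previous)

    def recompute_fn(*args, **kwargs):
        # Runs later, during backward: sees the forward-time snapshot even if
        # _GLOBAL_STATE has been mutated in the meantime.
        with state_ctx():
            return function(*args, **kwargs)

    return recompute_fn
```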