Unverified Commit a5181512 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix FP8 activation recompute (#1254)



Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6e90fcb7
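
The gist of the fix, as the hunks below show: FP8GlobalStateManager.IS_FIRST_FP8_MODULE is a one-shot flag that the first FP8 module in an iteration consumes (via is_first_fp8_module()) to claim ownership of the backward amax reduction. With activation recompute, the checkpointed forward runs twice, once in the original forward pass and once more during backward to regenerate activations, so the flag was being consumed in the wrong phase. The fix stashes the flag when entering the original forward and restores it when entering the recompute phase. A minimal standalone sketch of the pattern (the class below is a mock with the same read-and-clear semantics, not the real Transformer Engine API):

from typing import List

class MockFP8State:
    """Mock of FP8GlobalStateManager's one-shot flag (not the real API)."""
    IS_FIRST_FP8_MODULE = False

    @classmethod
    def is_first_fp8_module(cls) -> bool:
        # Read-and-clear: only one caller per iteration sees True.
        first = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return first

_stash: List[bool] = []

def checkpointed_forward(recompute_phase: bool) -> bool:
    if not recompute_phase:
        # Original forward: remember the flag before any module consumes it.
        _stash.append(MockFP8State.IS_FIRST_FP8_MODULE)
    else:
        # Recompute phase: restore the value the original forward saw.
        MockFP8State.IS_FIRST_FP8_MODULE = _stash.pop(0)
    return MockFP8State.is_first_fp8_module()

MockFP8State.IS_FIRST_FP8_MODULE = True             # start of a new iteration
assert checkpointed_forward(recompute_phase=False)  # original forward sees True
assert checkpointed_forward(recompute_phase=True)   # recompute sees True again
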
transformer_engine/pytorch/distributed.py

@@ -206,6 +206,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
     activations, followed by calculation of gradients using these values.
     """
 
+    _is_first_fp8_module: List = []
+
     def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
         super().__init__()
         self.activation_recompute = activation_recompute
@@ -218,6 +220,15 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
         )
         _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
 
+        if self.activation_recompute and not self.recompute_phase:
+            activation_recompute_forward._is_first_fp8_module.append(
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            )
+        if self.activation_recompute and self.recompute_phase:
+            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
+                activation_recompute_forward._is_first_fp8_module.pop(0)
+            )
+
     def __exit__(self, *exc_details):
         global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
         _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
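
The two __enter__ branches above pair up across phases: the non-recompute entry pushes the current flag, and the recompute entry pops it with pop(0), a FIFO order that matches checkpointed regions being replayed in the same order they first ran. Both phases are exercised by ordinary activation checkpointing under FP8; a hedged usage sketch (te.checkpoint, te.fp8_autocast, and te.Linear are the public entry points, while the shapes and the single-layer model here are arbitrary, and FP8 execution requires capable hardware):

import torch
import transformer_engine.pytorch as te

inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
layer = te.Linear(1024, 1024)

with te.fp8_autocast(enabled=True):
    # Runs the forward under activation_recompute_forward(recompute_phase=False),
    # which pushes IS_FIRST_FP8_MODULE onto the stash.
    out = te.checkpoint(layer, inp)

# Backward replays the forward under recompute_phase=True,
# which pops the stashed flag before the module consumes it.
out.sum().backward()
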
transformer_engine/pytorch/module/layernorm_linear.py

@@ -36,6 +36,7 @@ from ..distributed import (
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -361,10 +362,10 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.normalization = normalization
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
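
The same three-step pattern now appears inside the module's forward: save the raw flag, let is_first_fp8_module() consume it to decide whether this module owns the backward FP8 tensor reduction, and, when running in the recompute phase, write the saved value back so the replayed forward leaves the global flag as it found it. As a standalone illustration (mock names from the sketch above, not the real API):

def consume_without_side_effect(state, in_recompute_phase: bool) -> bool:
    """Decide reduction ownership; leave the global flag untouched on replay."""
    saved = state.IS_FIRST_FP8_MODULE
    owns_reduction = state.is_first_fp8_module()  # read-and-clear
    if in_recompute_phase:
        state.IS_FIRST_FP8_MODULE = saved  # undo the clear during recompute
    return owns_reduction
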
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -43,6 +43,7 @@ from ..distributed import (
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     use_reentrant_activation_recompute,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -516,7 +517,10 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.fp8 and requires_grad(
             inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
         ):
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
transformer_engine/pytorch/module/linear.py

@@ -33,6 +33,7 @@ from ..distributed import (
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -349,10 +350,10 @@ class _Linear(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
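
The same save/consume/restore block lands in all three FP8 autograd functions (_LayerNormLinear, _LayerNormMLP, _Linear), since any of them can be the first FP8 module in a model. Note also that in _LayerNormLinear and _Linear the old `ctx.reduce_and_update_bwd_fp8_tensors or ...` accumulation collapses to a plain assignment, which is equivalent because the flag is initialized to False two lines earlier.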