Unverified Commit a5181512 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix FP8 activation recompute (#1254)



Fix FP8 activation recompute
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6e90fcb7
@@ -206,6 +206,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
     activations, followed by calculation of gradients using these values.
     """
 
+    _is_first_fp8_module: List = []
+
     def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
         super().__init__()
         self.activation_recompute = activation_recompute
@@ -218,6 +220,15 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
         )
         _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
 
+        if self.activation_recompute and not self.recompute_phase:
+            activation_recompute_forward._is_first_fp8_module.append(
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            )
+        if self.activation_recompute and self.recompute_phase:
+            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
+                activation_recompute_forward._is_first_fp8_module.pop(0)
+            )
+
     def __exit__(self, *exc_details):
         global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
         _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
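The class-level list added above lets activation_recompute_forward carry FP8GlobalStateManager.IS_FIRST_FP8_MODULE from the first (non-recompute) pass of a checkpointed region over to its backward-time recompute. A rough sketch of how such contexts wrap the two passes, simplified from how TE's checkpoint utilities drive them (the actual wiring may differ in detail):

from transformer_engine.pytorch.distributed import activation_recompute_forward

def first_pass(fn, *args):
    # Original forward of a checkpointed region: the current
    # IS_FIRST_FP8_MODULE value is appended to _is_first_fp8_module.
    with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
        return fn(*args)

def recompute_pass(fn, *args):
    # Backward-time re-run: the saved value is popped and re-installed,
    # so the recompute observes the same "first FP8 module" state.
    with activation_recompute_forward(activation_recompute=True, recompute_phase=True):
        return fn(*args)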
@@ -36,6 +36,7 @@ from ..distributed import (
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -361,10 +362,10 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.normalization = normalization
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
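In _LayerNormLinear.forward, ctx.reduce_and_update_bwd_fp8_tensors is now taken directly from is_first_fp8_module(), and the global flag is put back afterwards when running in the recompute phase. A self-contained toy version of that idiom (stand-in names, not TE's internals; it assumes is_first_fp8_module() reads and then clears the flag, which is what the snapshot/restore implies):

IS_FIRST_FP8_MODULE = True  # stand-in for FP8GlobalStateManager.IS_FIRST_FP8_MODULE

def is_first_fp8_module() -> bool:
    # Return the flag and clear it, so only one module per pass sees True.
    global IS_FIRST_FP8_MODULE
    first, IS_FIRST_FP8_MODULE = IS_FIRST_FP8_MODULE, False
    return first

def module_forward_bookkeeping(in_recompute_phase: bool) -> bool:
    # Mirrors the changed lines: snapshot, consume, and (only during the
    # backward-time recompute) restore, so the re-run does not permanently
    # flip the global flag.
    global IS_FIRST_FP8_MODULE
    snapshot = IS_FIRST_FP8_MODULE
    reduce_and_update_bwd_fp8_tensors = is_first_fp8_module()
    if in_recompute_phase:
        IS_FIRST_FP8_MODULE = snapshot
    return reduce_and_update_bwd_fp8_tensors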
@@ -43,6 +43,7 @@ from ..distributed import (
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     use_reentrant_activation_recompute,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -516,7 +517,10 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.fp8 and requires_grad(
             inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
         ):
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
@@ -33,6 +33,7 @@ from ..distributed import (
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -349,10 +350,10 @@ class _Linear(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
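The same snapshot/restore is applied in _LayerNormMLP and _Linear above. All three code paths are exercised when TE modules are checkpointed under FP8; a minimal smoke-test sketch (assumes an FP8-capable GPU and the public te.fp8_autocast / te.checkpoint APIs):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

for module in (te.Linear(1024, 1024), te.LayerNormLinear(1024, 1024), te.LayerNormMLP(1024, 4096)):
    module.cuda()
    inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # te.checkpoint re-runs the forward during backward(); that re-run is
        # the recompute phase handled by the hunks above.
        out = te.checkpoint(module, inp)
    out.sum().backward()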