"vscode:/vscode.git/clone" did not exist on "b72c79f96469fa3a22aaa3cfdff35ac8622d34da"
Unverified commit c8175d9e, authored by Tian Zheng and committed by GitHub.
Browse files

[Paddle] Refactor FP8 state (#350)



Refactor fp8 state
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 95ec1560
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""FP8 utilities for TransformerEngine""" """FP8 utilities for TransformerEngine"""
import copy
from contextlib import contextmanager from contextlib import contextmanager
from typing import Tuple, Optional, Dict, Any from typing import Tuple, Optional, Dict, Any
...@@ -15,10 +16,6 @@ from transformer_engine.common.recipe import DelayedScaling, Format ...@@ -15,10 +16,6 @@ from transformer_engine.common.recipe import DelayedScaling, Format
# FP8 support # FP8 support
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
# FP8 status
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None
def _check_fp8_support() -> Tuple[bool, str]: def _check_fp8_support() -> Tuple[bool, str]:
...@@ -49,27 +46,40 @@ def is_fp8_available() -> Tuple[bool, str]: ...@@ -49,27 +46,40 @@ def is_fp8_available() -> Tuple[bool, str]:
return _is_fp8_available, _reason_for_no_fp8 return _is_fp8_available, _reason_for_no_fp8
class FP8State:
    """Stores FP8 state.

    Holds the enabled flag, the calibration flag, and the active FP8
    recipe. A single module-level instance (`_global_fp8_state`) is the
    source of truth; `fp8_autocast` mutates it and restores a saved copy
    on exit.
    """

    def __init__(self):
        # FP8 execution is off until fp8_autocast enables it.
        self.fp8_enabled = False
        # Calibration flag, toggled by fp8_autocast's `calibrating` argument.
        self.fp8_calibration = False
        # Active recipe; None until fp8_autocast assigns one
        # (default or user-provided).
        self.fp8_recipe = None

    def is_fp8_enabled(self) -> bool:
        """Is FP8 enabled"""
        return self.fp8_enabled

    def is_fp8_calibration(self) -> bool:
        """Is FP8 calibration"""
        return self.fp8_calibration

    def get_fp8_recipe(self) -> DelayedScaling:
        """Return the fp8 recipe"""
        return self.fp8_recipe

    @staticmethod
    def get_default_fp8_recipe() -> DelayedScaling:
        """FP8 recipe if not provided by user.

        Margin = 0, interval = 1, E4M3.
        """
        return DelayedScaling()


# Process-wide FP8 state; fp8_autocast swaps in new values and restores
# a deep copy of the previous state when the context exits.
_global_fp8_state = FP8State()


def get_global_fp8_state() -> FP8State:
    """Get global fp8 state"""
    return _global_fp8_state
@contextmanager @contextmanager
...@@ -82,19 +92,20 @@ def fp8_autocast( ...@@ -82,19 +92,20 @@ def fp8_autocast(
Context manager for FP8 usage. Context manager for FP8 usage.
""" """
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE global _global_fp8_state
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE) saved_fp8_state = copy.deepcopy(_global_fp8_state)
try: try:
_FP8_ENABLED = enabled _global_fp8_state.fp8_enabled = enabled
_FP8_CALIBRATION = calibrating _global_fp8_state.fp8_calibration = calibrating
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe _global_fp8_state.fp8_recipe = FP8State.get_default_fp8_recipe(
) if fp8_recipe is None else fp8_recipe
if enabled: if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available() fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8 assert fp8_available, reason_for_no_fp8
yield yield
finally: finally:
(_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE) = fp8_state _global_fp8_state = saved_fp8_state
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
......
...@@ -17,13 +17,11 @@ from paddle.fluid.framework import _dygraph_tracer ...@@ -17,13 +17,11 @@ from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors from ..constants import FP8BwdTensors
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8 from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8
from ..fp8 import ( from ..fp8 import (
get_fp8_recipe, FP8State,
get_default_fp8_recipe, FP8TensorMeta,
is_fp8_enabled,
is_fp8_calibration,
amax_and_scale_update, amax_and_scale_update,
get_global_fp8_state,
get_fp8_te_dtype, get_fp8_te_dtype,
FP8TensorMeta,
) )
from ..profile import nvtx_range from ..profile import nvtx_range
from ..utils import get_bias_dtype, cast_if_needed from ..utils import get_bias_dtype, cast_if_needed
...@@ -63,7 +61,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -63,7 +61,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_calibration = False self.fp8_calibration = False
self.fp8_meta = {} self.fp8_meta = {}
self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_checkpoint"] = False
self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe()
self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True)
self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)
...@@ -104,17 +102,18 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -104,17 +102,18 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
# assume FP8 execution. # assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None: def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop.""" """Initialize fp8 related metadata and tensors during fprop."""
self.fp8_enabled = is_fp8_enabled() state = get_global_fp8_state()
self.fp8_calibration = is_fp8_calibration() self.fp8_enabled = state.is_fp8_enabled()
self.fp8_calibration = state.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration
if self.fp8_enabled or self.fp8_calibration: if self.fp8_enabled or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything. # FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]: if self.fp8_initialized and state.get_fp8_recipe() == self.fp8_meta["recipe"]:
return return
# Set FP8, recipe, and other FP8 metadata # Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = get_fp8_recipe() self.fp8_meta["recipe"] = state.get_fp8_recipe()
# Set FP8_MAX per tensor according to recipe # Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment