Unverified commit c8175d9e authored by Tian Zheng, committed by GitHub
Browse files

[Paddle] Refactor FP8 state (#350)



Refactor fp8 state
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 95ec1560
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""
import copy
from contextlib import contextmanager
from typing import Tuple, Optional, Dict, Any
......@@ -15,10 +16,6 @@ from transformer_engine.common.recipe import DelayedScaling, Format
# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""
# FP8 status
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None
def _check_fp8_support() -> Tuple[bool, str]:
......@@ -49,29 +46,42 @@ def is_fp8_available() -> Tuple[bool, str]:
return _is_fp8_available, _reason_for_no_fp8
# Functions used to access fp8 status
def is_fp8_enabled() -> bool:
"""Is FP8 enabled"""
return _FP8_ENABLED
class FP8State:
    """Mutable container for the process-wide FP8 execution state.

    Replaces the former module-level ``_FP8_ENABLED`` / ``_FP8_CALIBRATION`` /
    ``_FP8_RECIPE`` globals. The fields are written directly by
    ``fp8_autocast`` and read back through the accessors below.
    """

    def __init__(self):
        # Defaults: FP8 off, not calibrating, no recipe chosen yet.
        self.fp8_enabled = False
        self.fp8_calibration = False
        self.fp8_recipe = None

    def is_fp8_enabled(self) -> bool:
        """Return True when FP8 execution is currently enabled."""
        return self.fp8_enabled

    def is_fp8_calibration(self) -> bool:
        """Return True when FP8 calibration mode is active."""
        return self.fp8_calibration

    def get_fp8_recipe(self) -> DelayedScaling:
        """Return the active FP8 recipe (``None`` until set by ``fp8_autocast``)."""
        return self.fp8_recipe

    @staticmethod
    def get_default_fp8_recipe() -> DelayedScaling:
        """Build the fallback recipe used when the caller supplies none.

        ``DelayedScaling()`` defaults: margin = 0, interval = 1, E4M3.
        """
        return DelayedScaling()
# Process-wide singleton; ``fp8_autocast`` saves, mutates, and restores it.
_global_fp8_state = FP8State()


def get_global_fp8_state() -> FP8State:
    """Return the module-level :class:`FP8State` singleton."""
    return _global_fp8_state
@contextmanager
def fp8_autocast(
enabled: bool = False,
......@@ -82,19 +92,20 @@ def fp8_autocast(
Context manager for FP8 usage.
"""
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE)
global _global_fp8_state
saved_fp8_state = copy.deepcopy(_global_fp8_state)
try:
_FP8_ENABLED = enabled
_FP8_CALIBRATION = calibrating
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
_global_fp8_state.fp8_enabled = enabled
_global_fp8_state.fp8_calibration = calibrating
_global_fp8_state.fp8_recipe = FP8State.get_default_fp8_recipe(
) if fp8_recipe is None else fp8_recipe
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
yield
finally:
(_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE) = fp8_state
_global_fp8_state = saved_fp8_state
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
......
......@@ -17,13 +17,11 @@ from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8
from ..fp8 import (
get_fp8_recipe,
get_default_fp8_recipe,
is_fp8_enabled,
is_fp8_calibration,
FP8State,
FP8TensorMeta,
amax_and_scale_update,
get_global_fp8_state,
get_fp8_te_dtype,
FP8TensorMeta,
)
from ..profile import nvtx_range
from ..utils import get_bias_dtype, cast_if_needed
......@@ -63,7 +61,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_calibration = False
self.fp8_meta = {}
self.fp8_meta["fp8_checkpoint"] = False
self.fp8_meta["recipe"] = get_default_fp8_recipe()
self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe()
self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True)
self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)
......@@ -104,17 +102,18 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8_enabled = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
state = get_global_fp8_state()
self.fp8_enabled = state.is_fp8_enabled()
self.fp8_calibration = state.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration
if self.fp8_enabled or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
if self.fp8_initialized and state.get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["recipe"] = state.get_fp8_recipe()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment