"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "29b0c9cad10d8651578ade7f26172d48b0e235a0"
Unverified commit d75db5f7 authored by Kirthi Shankar Sivamani, committed by GitHub

Remove interval arg from recipe (#892)



* Remove interval arg from recipe

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove usage of interval and use explicit kwarg for testing recipes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c1b915ae
@@ -85,7 +85,7 @@ PyTorch
     inp = torch.randn(hidden_size, in_features, device="cuda")

     # Create an FP8 recipe. Note: All input args are optional.
-    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
+    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

     # Enable autocasting for the forward pass
     with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
@@ -120,7 +120,7 @@ Flax
     inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

     # Create an FP8 recipe. Note: All input args are optional.
-    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)
+    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

     # Enable autocasting for the forward pass
     with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
...
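For call sites being updated, here is a minimal end-to-end sketch of the new PyTorch usage around the excerpt above. The sizes are illustrative placeholders, not values from the original example; the `te.Linear`/`fp8_autocast` API is as in the existing quickstart:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative sizes, not from the original example.
in_features, out_features, hidden_size = 768, 3072, 2048

model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# New-style recipe: `interval` is no longer passed.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

out.sum().backward()
```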
@@ -8,4 +8,4 @@ Common API
 .. autoapiclass:: transformer_engine.common.recipe.Format
-.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
+.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
@@ -25,12 +25,10 @@ class TestFP8Helper(unittest.TestCase):
     def test_initialize(self):
         margin = 5.0
         fp8_format = FP8Format.E4M3
-        update_fp8meta_interval = 10
         amax_history_len = 10
         FP8Helper.initialize(margin=margin,
                              fp8_format=fp8_format,
-                             update_fp8meta_interval=update_fp8meta_interval,
                              amax_history_len=amax_history_len)
         self.assertEqual(
@@ -40,10 +38,6 @@ class TestFP8Helper(unittest.TestCase):
             FP8Helper.FP8_FORMAT, fp8_format,
             f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
             f" but got {FP8Helper.FP8_FORMAT}.")
-        self.assertEqual(
-            FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
-            "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
-            f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
         self.assertEqual(
             FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
             f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
@@ -161,7 +155,6 @@ class TestFP8Functions(unittest.TestCase):
     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)
-        self.assertTrue(ref.interval == test.interval)
         self.assertTrue(ref.fp8_format == test.fp8_format)
         self.assertTrue(ref.amax_history_len == test.amax_history_len)
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
@@ -177,14 +170,14 @@ class TestFP8Functions(unittest.TestCase):
         self._check_defult_state()

-        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
+        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(FP8Helper.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

-        ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
+        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(FP8Helper.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
@@ -196,7 +189,7 @@ class TestFP8Functions(unittest.TestCase):
         FP8Helper.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

-        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
+        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         mesh_s = (
             (MeshResource(None, None)),
...
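The JAX-side pattern the updated tests exercise, as a standalone sketch. Import paths follow the test module above and may differ across versions:

```python
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast

# No `interval` kwarg anymore; FP8Helper picks up the remaining fields.
ds = DelayedScaling(margin=5.0, fp8_format=Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
    ...  # initialize/run Flax modules under FP8 here
```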
@@ -1197,7 +1197,6 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
-        interval=1,
         fp8_format=recipe.Format.HYBRID,
         amax_history_len=1,
         amax_compute_algo="most_recent",
@@ -1357,7 +1356,6 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
-        interval=1,
         fp8_format=recipe.Format.HYBRID,
         amax_history_len=1,
         amax_compute_algo="most_recent",
@@ -1557,7 +1555,6 @@ def _run_custom_mha_fp8(dtype, config, backend):
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
-        interval=1,
         fp8_format=recipe.Format.HYBRID,
         amax_history_len=1,
         amax_compute_algo="most_recent",
...
@@ -93,7 +93,7 @@ def reset_global_fp8_state():
 def create_fp8_recipe():
-    return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
+    return recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

 def do_export(
...
@@ -44,7 +44,6 @@ class TestFP8Recipe:
     fp8_format = transformer_engine.common.recipe.Format.HYBRID
     recipe = transformer_engine.common.recipe.DelayedScaling(
         margin=margin,
-        interval=1,
         fp8_format=fp8_format,
         amax_history_len=amax_history_len,
         amax_compute_algo=amax_compute_algo,
...
@@ -83,28 +83,34 @@ model_configs = {
 fp8_recipes = [
     None,  # Handles non-FP8 case
-    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
-    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
+    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
+    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
     recipe.DelayedScaling(
-        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
+        override_linear_precision=(False, False, True),
     ),
     recipe.DelayedScaling(
-        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
+        amax_history_len=16,
+        amax_compute_algo="most_recent",
     ),
     recipe.DelayedScaling(
-        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
+        amax_history_len=16,
+        amax_compute_algo="max",
     ),
     recipe.DelayedScaling(
-        0,
-        1,
-        recipe.Format.E4M3,
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
         amax_history_len=16,
         amax_compute_algo=custom_amax_compute,
     ),
     recipe.DelayedScaling(
-        0,
-        1,
-        recipe.Format.E4M3,
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
         amax_history_len=16,
         scaling_factor_compute_algo=custom_amax_to_scale,
     ),
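A note on why the list above switches from positional to keyword arguments: `interval` remains in the dataclass as a deprecated placeholder field (see the recipe change below), so a positional call would still bind its second argument to `interval` and hit the new deprecation warning. A small sketch of the hazard:

```python
from transformer_engine.common.recipe import DelayedScaling, Format

DelayedScaling(0, 1, Format.E4M3)                 # still binds interval=1 -> DeprecationWarning
DelayedScaling(margin=0, fp8_format=Format.E4M3)  # unambiguous, warning-free
```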
@@ -594,9 +600,8 @@ def test_sanity_gpt_126m():
     fp8_recipe = None
     if fp8_available:
         fp8_recipe = recipe.DelayedScaling(
-            0,
-            1,
-            recipe.Format.E4M3,
+            margin=0,
+            fp8_format=recipe.Format.E4M3,
             amax_history_len=16,
             amax_compute_algo="most_recent",
         )
@@ -657,9 +662,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
 def test_sanity_bert_126m():
     fp8_recipe = recipe.DelayedScaling(
-        0,
-        1,
-        recipe.Format.E4M3,
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
         amax_history_len=1,
         amax_compute_algo="most_recent",
     )
@@ -716,9 +720,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
 def test_sanity_T5_126m():
     fp8_recipe = recipe.DelayedScaling(
-        0,
-        1,
-        recipe.Format.E4M3,
+        margin=0,
+        fp8_format=recipe.Format.E4M3,
         amax_history_len=1,
         amax_compute_algo="most_recent",
     )
...
@@ -4,6 +4,7 @@
 """This module provides predefined FP8 recipes."""

 from __future__ import annotations
+import warnings
 from enum import Enum
 from typing import Literal, Optional, Union, Callable, NamedTuple
 from pydantic.dataclasses import dataclass
@@ -52,17 +53,13 @@ class _OverrideLinearPrecision(NamedTuple):
 @dataclass()
 class DelayedScaling:
     """
-    Use the delayed scaling factor strategy.
-    Use scale factor from previous iteration,
-    recompute once every `interval`, and record
-    amax history of `amax_history_len` steps.
+    Use the delayed scaling factor strategy. Use scale factor from previous
+    iteration and record amax history of `amax_history_len` steps.

     Parameters
     ----------
     margin : int, default = 0
             Margin for the scaling factor computation.
-    interval : int, default = 1
-            Controls how often the scaling factor is recomputed.
     fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
             Controls the FP8 data format used during forward and backward
             pass.
@@ -136,7 +133,7 @@ class DelayedScaling:
     """

     margin: int = 0
-    interval: int = 1
+    interval: int = -1
     fp8_format: Format = Format.HYBRID
     amax_history_len: int = 1024
     amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
@@ -152,11 +149,16 @@ class DelayedScaling:
             (False, False, False),
             (False, False, True),
         ), "Only wgrad GEMM override is currently supported."
+        if self.interval >= 0:
+            warnings.warn(
+                "`interval` argument is deprecated and unused. "
+                "It will be removed in an upcoming release.",
+                DeprecationWarning,
+            )

     def __repr__(self) -> str:
         return (
             f"margin={self.margin}, "
-            f"interval={self.interval}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"amax_history_len={self.amax_history_len}, "
             f"wgrad_override={self.override_linear_precision.wgrad}, "
...
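The `__post_init__` guard added above can be exercised directly; a minimal sketch (Python hides `DeprecationWarning` by default, hence the filter):

```python
import warnings
from transformer_engine.common.recipe import DelayedScaling

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DelayedScaling(interval=1)  # any interval >= 0 now triggers the warning
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```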
@@ -171,7 +171,6 @@ class FP8Helper:
     FP8_FORMAT: Format = Format.HYBRID
     FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
     BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
-    UPDATE_FP8META_INTERVAL: int = 1
     AMAX_HISTORY_LEN: int = 1024
     AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
     NUM_META_PER_GEMM: int = 3
@@ -197,7 +196,6 @@ class FP8Helper:
     @staticmethod
     def initialize(margin: float = 0.0,
                    fp8_format: Format = Format.HYBRID,
-                   update_fp8meta_interval: int = 1,
                    amax_history_len: int = 1,
                    amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
         """
@@ -208,7 +206,6 @@ class FP8Helper:
         FP8Helper.FP8_FORMAT = fp8_format
         FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
             _format2dtypes(FP8Helper.FP8_FORMAT)
-        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
         FP8Helper.AMAX_HISTORY_LEN = amax_history_len
         FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
         FP8Helper.FP8_2X_ACC_FPROP = False
@@ -225,7 +222,6 @@ class FP8Helper:
         FP8Helper.FP8_FORMAT = Format.HYBRID
         FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
             _format2dtypes(FP8Helper.FP8_FORMAT)
-        FP8Helper.UPDATE_FP8META_INTERVAL = 1
         FP8Helper.AMAX_HISTORY_LEN = 1024
         FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
@@ -407,11 +403,10 @@ def fp8_autocast(enabled: bool = False,
         pjit(transformer.init, ...)(...)

     .. note::
-        We only support :attr:`margin`, :attr:`fp8_format`,
-        :attr:`interval`, :attr:`amax_history_len` and
-        :attr:`amax_compute_algo`(with value 'max' and 'most_recent')
-        in recipe.DelayedScaling currently. Other parameters in
-        recipe.DelayedScaling will trigger an assertion.
+        We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
+        and :attr:`amax_compute_algo` (with values 'max' and 'most_recent') in
+        recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
+        will trigger an assertion.

     Parameters
     ----------
@@ -451,7 +446,6 @@ def fp8_autocast(enabled: bool = False,
     FP8Helper.initialize(margin=fp8_recipe.margin,
                          fp8_format=fp8_recipe.fp8_format,
-                         update_fp8meta_interval=fp8_recipe.interval,
                          amax_history_len=fp8_recipe.amax_history_len,
                          amax_compute_algo=amax_compute_algo)
     yield
@@ -512,10 +506,9 @@ def get_delayed_scaling():
     Obtain an instance of DelayedScaling which is set via fp8_autocast.

     .. note::
-        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
-        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
-        Other parameters in recipe.DelayedScaling would be returned as the default
-        values.
+        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
+        and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
+        recipe.DelayedScaling would be returned as the default values.

     Returns
     -------
@@ -525,7 +518,6 @@ def get_delayed_scaling():
     amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                         else "most_recent"
     return DelayedScaling(margin=int(FP8Helper.MARGIN),
-                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                           fp8_format=FP8Helper.FP8_FORMAT,
                           amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                           amax_compute_algo=amax_compute_algo)
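What the helper now round-trips, sketched against the updated note. The `get_delayed_scaling` import path is an assumption based on the module above and may vary by version:

```python
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fp8 import get_delayed_scaling  # path assumed

ds = DelayedScaling(margin=3.0, fp8_format=Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
    stored = get_delayed_scaling()
    # Only margin, fp8_format, amax_history_len and amax_compute_algo
    # are stored; other fields come back at their defaults.
    assert stored.fp8_format == ds.fp8_format
```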
@@ -80,9 +80,7 @@ class FP8State:
     @staticmethod
     def get_default_fp8_recipe() -> DelayedScaling:
-        """FP8 recipe if not provided by user
-        Margin = 0, interval = 1, E4M3
-        """
+        """FP8 recipe with default args."""
         return DelayedScaling()

     def get_autocast_id(self) -> int:
...
@@ -34,9 +34,7 @@ def check_fp8_support() -> Tuple[bool, str]:
 def get_default_fp8_recipe() -> DelayedScaling:
-    """FP8 recipe if not provided by user
-    Margin = 0, interval = 1, E4M3
-    """
+    """FP8 recipe with default args."""
     return DelayedScaling()
...
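With both docstrings now pointing at the dataclass defaults, and `interval` dropped from `__repr__` earlier in this commit, printing the default recipe gives roughly the following (output abbreviated and version-dependent):

```python
from transformer_engine.common.recipe import DelayedScaling

print(DelayedScaling())
# margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, ...
```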