Unverified Commit d75db5f7 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Remove interval arg from recipe (#892)



* Remove interval arg from recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove usage of interval and use explicit kwarg for testing recipes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c1b915ae
......@@ -85,7 +85,7 @@ PyTorch
inp = torch.randn(hidden_size, in_features, device="cuda")
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
......@@ -120,7 +120,7 @@ Flax
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
......
......@@ -8,4 +8,4 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.Format
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
......@@ -25,12 +25,10 @@ class TestFP8Helper(unittest.TestCase):
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
update_fp8meta_interval = 10
amax_history_len = 10
FP8Helper.initialize(margin=margin,
fp8_format=fp8_format,
update_fp8meta_interval=update_fp8meta_interval,
amax_history_len=amax_history_len)
self.assertEqual(
......@@ -40,10 +38,6 @@ class TestFP8Helper(unittest.TestCase):
FP8Helper.FP8_FORMAT, fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.")
self.assertEqual(
FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
"FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
self.assertEqual(
FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
......@@ -161,7 +155,6 @@ class TestFP8Functions(unittest.TestCase):
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
self.assertTrue(ref.interval == test.interval)
self.assertTrue(ref.fp8_format == test.fp8_format)
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
......@@ -177,14 +170,14 @@ class TestFP8Functions(unittest.TestCase):
self._check_defult_state()
ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state()
ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
......@@ -196,7 +189,7 @@ class TestFP8Functions(unittest.TestCase):
FP8Helper.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
mesh_s = (
(MeshResource(None, None)),
......
......@@ -1197,7 +1197,6 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
......@@ -1357,7 +1356,6 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
......@@ -1557,7 +1555,6 @@ def _run_custom_mha_fp8(dtype, config, backend):
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
......
......@@ -93,7 +93,7 @@ def reset_global_fp8_state():
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
return recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
def do_export(
......
......@@ -44,7 +44,6 @@ class TestFP8Recipe:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
......
......@@ -83,28 +83,34 @@ model_configs = {
fp8_recipes = [
None, # Handles non-FP8 case
recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling(
0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
margin=0,
fp8_format=recipe.Format.E4M3,
override_linear_precision=(False, False, True),
),
recipe.DelayedScaling(
0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="most_recent",
),
recipe.DelayedScaling(
0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="max",
),
recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo=custom_amax_compute,
),
recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
scaling_factor_compute_algo=custom_amax_to_scale,
),
......@@ -594,9 +600,8 @@ def test_sanity_gpt_126m():
fp8_recipe = None
if fp8_available:
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="most_recent",
)
......@@ -657,9 +662,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
def test_sanity_bert_126m():
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
......@@ -716,9 +720,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
def test_sanity_T5_126m():
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
margin=0,
fp8_format=recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
......
......@@ -4,6 +4,7 @@
"""This module provides predefined FP8 recipes."""
from __future__ import annotations
import warnings
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass
......@@ -52,17 +53,13 @@ class _OverrideLinearPrecision(NamedTuple):
@dataclass()
class DelayedScaling:
"""
Use the delayed scaling factor strategy.
Use scale factor from previous iteration,
recompute once every `interval`, and record
amax history of `amax_history_len` steps.
Use the delayed scaling factor strategy. Use scale factor from previous
iteration and record amax history of `amax_history_len` steps.
Parameters
----------
margin : int, default = 0
Margin for the scaling factor computation.
interval : int, default = 1
Controls how often the scaling factor is recomputed.
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
......@@ -136,7 +133,7 @@ class DelayedScaling:
"""
margin: int = 0
interval: int = 1
interval: int = -1
fp8_format: Format = Format.HYBRID
amax_history_len: int = 1024
amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
......@@ -152,11 +149,16 @@ class DelayedScaling:
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
if self.interval >= 0:
warnings.warn(
"`interval` argument is deprecated and unused. "
"It will be removed in an upcoming release.",
DeprecationWarning,
)
def __repr__(self) -> str:
return (
f"margin={self.margin}, "
f"interval={self.interval}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
......
......@@ -171,7 +171,6 @@ class FP8Helper:
FP8_FORMAT: Format = Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
NUM_META_PER_GEMM: int = 3
......@@ -197,7 +196,6 @@ class FP8Helper:
@staticmethod
def initialize(margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1,
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
"""
......@@ -208,7 +206,6 @@ class FP8Helper:
FP8Helper.FP8_FORMAT = fp8_format
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
FP8Helper.FP8_2X_ACC_FPROP = False
......@@ -225,7 +222,6 @@ class FP8Helper:
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
......@@ -407,11 +403,10 @@ def fp8_autocast(enabled: bool = False,
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`,
:attr:`interval`, :attr:`amax_history_len` and
:attr:`amax_compute_algo`(with value 'max' and 'most_recent')
in recipe.DelayedScaling currently. Other parameters in
recipe.DelayedScaling will trigger an assertion.
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo`(with value 'max' and 'most_recent') in
recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
will trigger an assertion.
Parameters
----------
......@@ -451,7 +446,6 @@ def fp8_autocast(enabled: bool = False,
FP8Helper.initialize(margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
update_fp8meta_interval=fp8_recipe.interval,
amax_history_len=fp8_recipe.amax_history_len,
amax_compute_algo=amax_compute_algo)
yield
......@@ -512,10 +506,9 @@ def get_delayed_scaling():
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
:attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
Other parameters in recipe.DelayedScaling would be returned as the default
values.
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
......@@ -525,7 +518,6 @@ def get_delayed_scaling():
amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
else "most_recent"
return DelayedScaling(margin=int(FP8Helper.MARGIN),
interval=FP8Helper.UPDATE_FP8META_INTERVAL,
fp8_format=FP8Helper.FP8_FORMAT,
amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo)
......@@ -80,9 +80,7 @@ class FP8State:
@staticmethod
def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
"""FP8 recipe with default args."""
return DelayedScaling()
def get_autocast_id(self) -> int:
......
......@@ -34,9 +34,7 @@ def check_fp8_support() -> Tuple[bool, str]:
def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
"""FP8 recipe with default args."""
return DelayedScaling()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment