"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "2bb532fbf33766c089f4193dd6ef745cae5301d3"
Unverified commit 1669b3f4 authored by Kirthi Shankar Sivamani, committed by GitHub

Add docs for missing FP8 recipes. (#1816)

Document all recipes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b17f3f4e
@@ -11,3 +11,7 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)
.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)
.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID)
.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3)
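A minimal usage sketch (illustrative only, not part of this commit) showing how one of the recipes documented above is passed to `fp8_autocast`; it assumes the `transformer_engine.pytorch` integration and a CUDA device are available.

```python
# Illustrative sketch, not part of this commit: pick a documented recipe and
# run a Transformer Engine module under fp8_autocast with it.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Per-tensor current scaling; HYBRID = E4M3 forward, E5M2 backward.
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```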
@@ -193,42 +193,12 @@ class DelayedScaling(Recipe):
class Float8CurrentScaling(Recipe):
"""
Use the per-tensor current scaling factor strategy.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of gradient tensor dY
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA is currently only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules,
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
`fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
`LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes
-----
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
""" """
fp8_format: Format = Format.HYBRID fp8_format: Format = Format.HYBRID
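The parameter list above references `QParams` and `MMParams`; the following is a hedged sketch (not part of the diff) of how a caller might override them on `Float8CurrentScaling`. The `QParams`/`MMParams` field names are taken from the defaults quoted in the docstring and should be treated as assumptions.

```python
# Hedged sketch, not part of this commit: overriding the quantization and
# GEMM parameters listed in the docstring above.
from transformer_engine.common.recipe import (
    Float8CurrentScaling,
    Format,
    MMParams,
    QParams,
)

recipe = Float8CurrentScaling(
    fp8_format=Format.HYBRID,
    # Add a small epsilon to the amax of the input tensor x.
    fp8_quant_fwd_inp=QParams(power_2_scale=False, amax_epsilon=1e-12),
    # Use split accumulation for the forward GEMM as well.
    fp8_gemm_fprop=MMParams(use_split_accumulator=True),
)
```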
@@ -243,6 +213,9 @@ class Float8CurrentScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8CurrentScaling."
def __repr__(self) -> str:
return (
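For readers skimming the diff, a hedged sketch (not part of the commit) of what the `__post_init__` checks, including the newly added FP8-attention assert, accept and reject:

```python
# Hedged sketch, not part of this commit: behaviour of the __post_init__ asserts.
from transformer_engine.common.recipe import Float8CurrentScaling, Format

Float8CurrentScaling(fp8_format=Format.HYBRID)  # accepted
Float8CurrentScaling(fp8_format=Format.E4M3)    # accepted

try:
    Float8CurrentScaling(fp8_format=Format.E5M2)
except AssertionError as err:
    print(err)  # Pure E5M2 training is not supported.

try:
    Float8CurrentScaling(fp8_dpa=True)  # fp8_dpa / fp8_mha must stay False
except AssertionError as err:
    print(err)  # FP8 attention is not supported for Float8CurrentScaling.
```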
@@ -319,32 +292,12 @@ class Float8BlockScaling(Recipe):
NOTE: To relax the default constraint that scales be powers of 2, set env variable
NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Or initialize the Recipe with non-default QParams in code for increased control.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of gradient tensor dY
x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for x.
w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for w.
grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for grad.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
""" """
use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
...@@ -378,6 +331,9 @@ class Float8BlockScaling(Recipe): ...@@ -378,6 +331,9 @@ class Float8BlockScaling(Recipe):
assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop."
assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad."
assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8BlockScaling."
def __repr__(self) -> str:
return (
...
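To tie the Float8BlockScaling notes together, here is a hedged sketch (not part of this commit) of the two override paths the docstring mentions: the `NVTE_FP8_BLOCK_SCALING_FP32_SCALES` environment variable and explicit non-default `QParams`. The `QParams` field names are assumed from the defaults quoted above.

```python
# Hedged sketch, not part of this commit: relaxing the power-of-2 scale
# constraint for Float8BlockScaling in the two ways the docstring mentions.
import os

# Option 1: set the env variable before the recipe module is imported so the
# class default (use_f32_scales) picks it up.
os.environ["NVTE_FP8_BLOCK_SCALING_FP32_SCALES"] = "1"

from transformer_engine.common.recipe import Float8BlockScaling, QParams

# Option 2: pass non-default QParams explicitly for per-tensor control.
fp32_qparams = QParams(power_2_scale=False, amax_epsilon=0.0)
recipe = Float8BlockScaling(
    fp8_quant_fwd_inp=fp32_qparams,
    fp8_quant_fwd_weight=fp32_qparams,
    fp8_quant_bwd_grad=fp32_qparams,
)
```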