Unverified Commit 85a91997 authored by Kirthi Shankar Sivamani, committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make Jax change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CG function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
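In short, this commit replaces the FP8-specific entry points with recipe-agnostic ones: `fp8_autocast` becomes `autocast` (with `fp8_recipe` renamed to `recipe`), `fp8_model_init` becomes `quantized_model_init`, and the availability checks move from `FP8GlobalStateManager` static methods to public top-level functions. A minimal before/after sketch based on the call sites changed in the diff below; `DelayedScaling` stands in for any supported recipe, and the shapes are illustrative only:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    recipe = DelayedScaling()
    x = torch.randn([16, 16], device="cuda")

    # Old API (removed call sites in this diff):
    #   with te.fp8_model_init(enabled=True, recipe=recipe):
    #       linear = te.Linear(16, 16)
    #   with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    #       y = linear(x)

    # New API (added call sites in this diff):
    with te.quantized_model_init(enabled=True, recipe=recipe):
        linear = te.Linear(16, 16)
    with te.autocast(enabled=True, recipe=recipe):
        y = linear(x)

    # Availability checks are now public functions:
    fp8_available, reason = te.is_fp8_available(return_reason=True)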
@@ -5,7 +5,7 @@
 import pytest
 import torch
-import transformer_engine.pytorch as te
+import transformer_engine.pytorch
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
...
@@ -12,18 +12,15 @@ import torch
 import torch.nn as nn
 from torch.nn import Parameter
-from transformer_engine.pytorch.fp8 import (
-    FP8GlobalStateManager,
-    fp8_autocast,
-    fp8_model_init,
-)
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
     attention_mask_func,
-    is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
+    autocast,
+    quantized_model_init,
     DotProductAttention,
     LayerNormLinear,
     LayerNormMLP,
@@ -35,26 +32,28 @@ from transformer_engine.pytorch import (
     LayerNorm,
     Fp8Padding,
     Fp8Unpadding,
-)
-from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
-from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
-from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,
     Float8CurrentScalingQuantizer,
+    MXFP8Quantizer,
+    get_device_compute_capability,
+    is_fp8_available,
+    is_mxfp8_available,
+    is_fp8_block_scaling_available,
+    is_bf16_available,
 )
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch import checkpoint as te_checkpoint
+from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states, get_available_attention_backends

 # Only run FP8 tests on supported devices.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available = is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -77,7 +76,7 @@ module_inference = ["TransformerLayer", "MultiheadAttention"]
 input_formats_inference = ["sbhd", "bshd"]
 param_types = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)
 batch_sizes = [1, 2]
@@ -548,7 +547,7 @@ def _test_e2e_selective_recompute(
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         block = TransformerLayer(
             config.hidden_size,
             4 * config.hidden_size,
@@ -575,7 +574,7 @@ def _test_e2e_selective_recompute(
     te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+    with autocast(enabled=fp8, recipe=recipe):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,
@@ -637,7 +636,7 @@ def _test_e2e_full_recompute(
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         block = TransformerLayer(
             config.hidden_size,
             4 * config.hidden_size,
@@ -665,7 +664,7 @@ def _test_e2e_full_recompute(
     te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+    with autocast(enabled=fp8, recipe=recipe):
         if recompute:
             te_out = te_checkpoint(
                 block,
@@ -1088,7 +1087,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
     )
     inp_hidden_states.retain_grad()
-    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+    with autocast(enabled=fp8, recipe=recipe):
         out = block(inp_hidden_states)
         if isinstance(out, (List, Tuple)):
             out = out[0]
@@ -1304,7 +1303,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         te_linear_ref = Linear(
             config.hidden_size,
             4 * config.hidden_size,
@@ -1758,7 +1757,7 @@ def _test_grouped_linear_accuracy(
     else:
         m_splits = torch.tensor([config.max_seqlen_q])
-    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+    with autocast(enabled=fp8, recipe=recipe):
         if isinstance(block, GroupedLinear):
             m_splits = m_splits * bs
             out = block(inp_hidden_states, m_splits.tolist())
@@ -1820,7 +1819,7 @@ def test_grouped_linear_accuracy(
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
@@ -1956,7 +1955,7 @@ def test_grouped_linear_accuracy_save_original_input(
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
@@ -2110,7 +2109,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
     m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
-    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+    with autocast(enabled=fp8, recipe=recipe):
         if isinstance(block, TorchGroupedLinearWithPadding):
             out = block(inp_hidden_states, m_splits)
         else:
@@ -2158,7 +2157,7 @@ def test_padding_grouped_linear_accuracy(
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = TorchGroupedLinearWithPadding(
             num_gemms,
             config.hidden_size,
@@ -2169,7 +2168,7 @@ def test_padding_grouped_linear_accuracy(
         fp8=fp8,
     ).eval()
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         ref_grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
@@ -2229,7 +2228,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = TorchGroupedLinearWithPadding(
             num_gemms,
             config.hidden_size,
@@ -2240,7 +2239,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
         fp8=fp8,
     ).eval()
-    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         ref_grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
@@ -2390,7 +2389,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
         block = TransformerLayer(
             config.hidden_size,
             4 * config.hidden_size,
@@ -2417,7 +2416,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
     te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-    with fp8_autocast(enabled=True, fp8_recipe=recipe):
+    with autocast(enabled=True, recipe=recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()
...
@@ -34,7 +34,7 @@ import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
 import tensorrt as trt
@@ -57,8 +57,8 @@ NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
 # The directory where this file is stored.
 TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 fp8_recipes = []
 if mxfp8_available:
@@ -178,8 +178,8 @@ def do_export(
     input_names = input_names or ["input"]
     output_names = output_names or ["output"]
-    with torch.inference_mode(), te.fp8_autocast(
-        enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+    with torch.inference_mode(), te.autocast(
+        enabled=fp8_recipe is not None, recipe=fp8_recipe
     ), warnings.catch_warnings():
         warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
@@ -233,8 +233,8 @@ def te_infer(
     fp8_recipe: recipe.Recipe,
 ):
     """Transformer Engine forward propagation."""
-    with torch.inference_mode(), te.fp8_autocast(
-        enabled=is_fp8, fp8_recipe=fp8_recipe
+    with torch.inference_mode(), te.autocast(
+        enabled=is_fp8, recipe=fp8_recipe
     ), warnings.catch_warnings():
         te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
         if not isinstance(te_outputs, tuple):
@@ -440,7 +440,7 @@ def _test_export_linear(
     bias_str = "_bias" if use_bias else ""
     high_prec_str = dtype2str(precision)
     fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
         model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
             device="cuda"
         )
@@ -500,7 +500,7 @@ def _test_export_layernorm(
     fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
     with torch.no_grad():
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
             layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
             model = layernorm_cls(
                 hidden_size,
@@ -568,7 +568,7 @@ def _test_export_layernorm_linear(
     fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
     with torch.no_grad():
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
             model = te.LayerNormLinear(
                 hidden_size,
                 3 * hidden_size,
@@ -654,7 +654,7 @@ def _test_export_layernorm_mlp(
     bias_str = "_bias" if use_bias else ""
     high_prec_str = dtype2str(precision)
     fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
         model = te.LayerNormMLP(
             hidden_size,
             ffn_hidden_size,
@@ -1160,13 +1160,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
     inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
         out_ref = model(*inps)
     onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(onnx_fd)
     try:
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
             with te.onnx_export(enabled=True):
                 torch.onnx.export(
                     model,
...
@@ -4,7 +4,7 @@
 import random
 import torch
-from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
+from transformer_engine.pytorch import parallel_cross_entropy
 from utils import dtype_tols
...
@@ -8,6 +8,7 @@ import torch
 import pytest
 from typing import Dict, List
+import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import (
     moe_permute as te_permute,
@@ -16,14 +17,12 @@ from transformer_engine.pytorch import (
     moe_sort_chunks_by_index as te_sort_chunks_by_index,
     moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
 )
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
     Float8Quantizer,
     Float8CurrentScalingQuantizer,
+    Float8BlockQuantizer,
+    MXFP8Quantizer,
 )
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine_torch as tex
 import copy
@@ -1119,7 +1118,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
 # TE tensor dtypes
 _te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
-if is_bf16_compatible():
+if te.is_bf16_available():
     _te_dtypes.append(tex.DType.kBFloat16)
@@ -1239,10 +1238,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
 # Only run FP8 tests on H100.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
 )
 fp8_recipes = [
     recipe.MXFP8BlockScaling(),
...
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
-from typing import Iterable, Optional
+from typing import Optional
 import pytest
 import torch
@@ -10,28 +10,34 @@ import warnings
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch import (
+    Float8BlockQuantizer,
+    MXFP8Quantizer,
+    Float8Quantizer,
+    NVFP4Quantizer,
+    quantized_model_init,
+    Linear,
+    LayerNormLinear,
+    LayerNormMLP,
+    GroupedLinear,
+)
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import (
+from transformer_engine.pytorch.quantization import (
     FP8GlobalStateManager,
     _amax_and_scale_update,
-    fp8_model_init,
 )
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
-from transformer_engine.pytorch.distributed import fp8_autocast
 from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
 import transformer_engine_torch as tex

 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
 )
+fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True)

 # FP8 per tensor delayed scaling
@@ -64,7 +70,7 @@ class TestFP8Recipe:
             amax_history_len=amax_history_len,
             amax_compute_algo=amax_compute_algo,
         )
-        with te.fp8_autocast(fp8_recipe=recipe):
+        with te.autocast(recipe=recipe):
             module = te.Linear(16, 16)
             y = module(
                 torch.randn([16, 16], device="cuda"),
@@ -120,7 +126,7 @@ class TestFP8Recipe:
         # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

         # Perform forward, backward, and optimizer steps to update fp8_meta
-        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+        with te.autocast(enabled=True, recipe=recipe):
             x = torch.randn([16, 16], device="cuda")
             y = module(x, is_first_microbatch=is_first_microbatch)
             y.backward(torch.randn_like(y))
@@ -219,7 +225,7 @@ class TestFP8Recipe:
             op.weight.fill_(w_history[-1])

         # Forward and backward pass
-        with te.fp8_autocast(fp8_recipe=recipe):
+        with te.autocast(recipe=recipe):
             y = op(x)
         y.backward(dy)
@@ -301,7 +307,7 @@ class TestFP8Recipe:
         scaling_factor_compute_algo = None
         if fused_update:
             scaling_factor_compute_algo = (
-                lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
+                lambda amax, scale, fp8_max, recipe: te.quantization._default_sf_compute(
                     amax, scale, fp8_max, recipe.margin
                 )
             )
@@ -311,7 +317,7 @@ class TestFP8Recipe:
         # Setup fp8_meta dictionary
         def setup_fp8_meta():
-            with te.fp8_autocast(fp8_recipe=recipe):
+            with te.autocast(recipe=recipe):
                 module = te.Linear(16, 16)
                 y = module(torch.zeros([16, 16], device="cuda"))
                 y.backward(torch.zeros_like(y))
@@ -393,11 +399,11 @@ class TestFP8Recipe:
         ],
     )
     def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
-        with fp8_model_init(enabled=True, recipe=model_init_recipe):
+        with quantized_model_init(enabled=True, recipe=model_init_recipe):
             linear = Linear(32, 32).cuda()
         x = torch.randn(32, 32, device="cuda")
-        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
+        with te.autocast(enabled=True, recipe=DelayedScaling()):
             with pytest.raises(RuntimeError) as excinfo:
                 _ = linear(x)
             assert "Recipe mismatch for " in str(excinfo.value)
@@ -436,7 +442,7 @@ class TestFP8Recipe:
         # Run initial iterations with DelayedScaling
        for _ in range(3):
             x = torch.randn(batch_size, in_features, device="cuda")
-            with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
+            with te.autocast(enabled=True, recipe=initial_recipe):
                 y = linear(x)
             loss = y.mean()
             loss.backward()
@@ -453,7 +459,7 @@ class TestFP8Recipe:
             if i == 0:
                 # Expect a warning on the first iteration with the new recipe
                 with pytest.warns(UserWarning, match="Recipe type changed"):
-                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                    with te.autocast(enabled=True, recipe=target_recipe):
                         y = linear(x)
                 for quantizer in linear.quantizers["scaling_fwd"]:
                     assert isinstance(quantizer, expected_quantizer_type)
@@ -461,7 +467,7 @@ class TestFP8Recipe:
                 # No warning expected on subsequent iterations
                 with warnings.catch_warnings():
                     warnings.simplefilter("error")  # Raise error if unexpected warning occurs
-                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                    with te.autocast(enabled=True, recipe=target_recipe):
                         y = linear(x)
             loss = y.mean()
             loss.backward()
@@ -485,7 +491,7 @@ class TestFP8Recipe:
         batch_size = 32
         recipe = DelayedScaling(amax_history_len=1024)
-        with fp8_model_init(recipe=recipe):
+        with quantized_model_init(recipe=recipe):
             if module_class == GroupedLinear:
                 module = module_class(1, in_features, out_features).cuda()
             else:
@@ -493,7 +499,7 @@ class TestFP8Recipe:
         x = torch.randn(batch_size, in_features, device="cuda")
         recipe = DelayedScaling(amax_history_len=1)
-        with fp8_autocast(enabled=True, fp8_recipe=recipe):
+        with te.autocast(enabled=True, recipe=recipe):
             warn_msg = "Quantizer is being updated, this may affect model behavior"
             with pytest.warns(UserWarning, match=warn_msg):
                 if module_class == GroupedLinear:
@@ -502,9 +508,6 @@ class TestFP8Recipe:
                     y = module(x)

-fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available()
-
 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
 @pytest.mark.parametrize(
...
@@ -8,18 +8,16 @@ import torch
 import pytest
 import os

-import transformer_engine.pytorch
-from transformer_engine.pytorch.fp8 import (
-    fp8_autocast,
-    FP8GlobalStateManager,
-    fp8_model_init,
-)
+import transformer_engine
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
-    is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
+    autocast,
+    quantized_model_init,
     LayerNormLinear,
     Linear,
     GroupedLinear,
@@ -27,26 +25,25 @@ from transformer_engine.pytorch import (
     TransformerLayer,
     RMSNorm,
     LayerNorm,
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+    MXFP8Tensor,
+    checkpoint,
+    QuantizedTensor,
+    is_bf16_available,
 )
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
-from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8CurrentScalingQuantizer,
-    Float8Quantizer,
-    Float8Tensor,
-)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
-from transformer_engine.pytorch.distributed import checkpoint
 from utils import ModelConfig

 # Only run FP8 tests on supported devices.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

 # Record initial RNG state from script run.
 seed = 1234
@@ -108,7 +105,7 @@ if fp8_available:
     fp8_recipes.append(None)

 param_types = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)

 all_boolean = [True, False]
@@ -160,7 +157,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
     use_fp8 = fp8_recipe is not None
     with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
-        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        with autocast(enabled=use_fp8, recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
             loss = te_out.sum()
@@ -199,7 +196,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
         p.main_grad = torch.zeros_like(p)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
         loss = te_out.sum()
         loss.backward()
@@ -227,7 +224,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states)
         loss = te_out.sum()
         loss.backward()
@@ -253,7 +250,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
         loss = te_out.sum()
         loss.backward()
@@ -285,7 +282,7 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,
@@ -314,7 +311,7 @@ def _test_sanity_common(
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         if not microbatching:
             te_out = block(te_inp)
         else:
@@ -455,7 +452,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
         pytest.skip("FP16 output for NVFP4 not supported")
     use_fp8 = fp8_recipe is not None
-    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
         te_linear = Linear(
             config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
         ).cuda()
@@ -463,7 +460,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         out = te_linear(inp_hidden_states)
     loss = out.sum()
     loss.backward()
@@ -496,7 +493,7 @@ def test_sanity_grouped_linear(
         pytest.skip("NVFP4 not supported for grouped linear")
     use_fp8 = fp8_recipe is not None
-    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
         te_grouped_linear = GroupedLinear(
             num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
         ).cuda()
@@ -512,7 +509,7 @@ def test_sanity_grouped_linear(
     elif empty_split == "middle":
         m_splits[num_gemms // 2] = 0
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         out = te_grouped_linear(inp_hidden_states, m_splits)
     loss = out.sum()
     loss.backward()
@@ -976,9 +973,9 @@ def test_replace_raw_data_for_float8tensor():

 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-def test_fp8_model_init_high_precision_init_val():
-    """Test fp8_model_init with preserve_high_precision_init_val=True"""
-    with fp8_model_init(preserve_high_precision_init_val=True):
+def test_quantized_model_init_high_precision_init_val():
+    """Test quantized_model_init with preserve_high_precision_init_val=True"""
+    with quantized_model_init(preserve_high_precision_init_val=True):
         model = Linear(768, 768)
     weight = model.weight
@@ -1051,7 +1048,7 @@ def test_linear_frozen_weights_memory_default_recipe():
     linear.weight.requires_grad = False

     # Forward and backward pass with FP8
-    with fp8_autocast():
+    with autocast():
         o = linear(x)
     g_o = torch.randn_like(o)
@@ -1105,7 +1102,7 @@ def test_inference_mode(
     # Construct module
     module = None
     with torch.no_grad():
-        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
+        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
             if module_name == "Linear":
                 module = Linear(hidden_size, hidden_size)
             elif module_name == "LayerNormLinear":
@@ -1140,6 +1137,6 @@ def test_inference_mode(
     kwargs = {}
     if module_name == "GroupedLinear":
         kwargs["m_splits"] = [sequence_length]
-    with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
+    with autocast(enabled=with_quantization, recipe=quantization_recipe):
         y = module(x, **kwargs)
     check_weights()
...
@@ -7,14 +7,14 @@ from __future__ import annotations
 import logging
 import os
 from contextlib import contextmanager
-from typing import Optional, Tuple, Dict, Any, List

+import pytest
 import torch

 import transformer_engine
+import transformer_engine.common.recipe
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import InferenceParams
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     get_attention_backend,
...
@@ -161,7 +161,7 @@ class DelayedScaling(Recipe):
                 where `Tensor` is a framework tensor type.
     reduce_amax: bool, default = `True`
                 By default, if `torch.distributed` is initialized, the `amax` value for FP8
-                tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
+                tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
                 call). This keeps the amaxes and scaling factors synced across the given
                 distributed group. If set to `False`, this reduction is skipped and every
                 GPU maintains local amaxes and scaling factors. To ensure results are
@@ -169,7 +169,7 @@ class DelayedScaling(Recipe):
                 ranks must checkpoint in order to store the local tensors.
     fp8_dpa: bool, default = `False`
                 Whether to enable FP8 dot product attention (DPA). When the model is placed in an
-                `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
+                `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
                 inputs from higher precision to FP8, performs attention in FP8, and casts tensors
                 back to higher precision as outputs. FP8 DPA currently is only supported in the
                 `FusedAttention` backend.
...
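The updated `DelayedScaling` docstring above refers to an `amax_reduction_group` argument on `autocast`. A hedged sketch of distributed amax reduction under that reading; `model` and `x` are placeholder stand-ins, and `torch.distributed` must already be initialized:

    import torch.distributed as dist
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    # reduce_amax=True (the default) reduces amaxes across the group,
    # keeping scaling factors synced between ranks.
    recipe = DelayedScaling(reduce_amax=True)
    group = dist.new_group(ranks=list(range(dist.get_world_size())))  # assumed group
    with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=group):
        y = model(x)  # placeholder module and input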
@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format
 from transformer_engine.pytorch.tensor import Quantizer
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
-from transformer_engine.pytorch.fp8 import _default_sf_compute
+from transformer_engine.pytorch.quantization import _default_sf_compute

 def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
...
@@ -34,7 +34,7 @@ load_framework_extension("jax")
 from . import flax
 from . import quantize
-from .quantize import fp8_autocast, update_collections
+from .quantize import autocast, fp8_autocast, update_collections
 from .quantize import NVTE_FP8_COLLECTION_NAME
 from .sharding import MeshResource
@@ -45,6 +45,7 @@ from ..common.utils import DeprecatedEnum
 __all__ = [
     "NVTE_FP8_COLLECTION_NAME",
+    "autocast",
     "fp8_autocast",
     "update_collections",
     "MeshResource",
...
@@ -66,7 +66,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
         for 1D-sharding tensor parallelism.

     .. warning::
-        Please make sure ShardingResource is set via fp8_autocast before calling this function.
+        Please make sure ShardingResource is set via autocast before calling this function.

     .. note::
         This function is only needed when using TransformerLayer. For other modules, such as
...
...@@ -7,6 +7,7 @@ Config module for quantization metadata management ...@@ -7,6 +7,7 @@ Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes. in JAX, including support for different scaling modes and datatypes.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
@@ -23,7 +24,14 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
-from transformer_engine.common import recipe
+from transformer_engine.common.recipe import (
+    Recipe,
+    DelayedScaling,
+    Format,
+    MXFP8BlockScaling,
+    Float8CurrentScaling,
+    NVFP4BlockScaling,
+)
 from transformer_engine.jax.sharding import (
     global_shard_guard,
     MeshResource,
@@ -39,6 +47,7 @@ from .device_utils import get_device_compute_capability
 __all__ = [
     "get_quantize_config",
     "get_quantize_config_with_recipe",
+    "autocast",
     "fp8_autocast",
     "is_fp8_available",
     "is_scaling_mode_supported",
@@ -51,8 +60,6 @@ __all__ = [
     "TensorSource",
 ]

-_is_fp8_available = None
-_reason_for_no_fp8 = ""
 _is_scaling_mode_supported = None
 _reason_for_no_scaling_mode = ""
 Collection = Union[Dict, FrozenDict]
@@ -195,22 +202,22 @@ def get_supported_scaling_modes() -> List[ScalingMode]:
     ]


-def get_supported_quantization_recipes() -> List[recipe.Recipe]:
+def get_supported_quantization_recipes() -> List[Recipe]:
     """Get all supported quantization recipes."""
     # We don't support all the recipes TE/Common supports yet
     # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()]
     all_recipes = [
-        recipe.DelayedScaling(),
-        recipe.Float8CurrentScaling(),
-        recipe.MXFP8BlockScaling(),
-        recipe.NVFP4BlockScaling(),
+        DelayedScaling(),
+        Float8CurrentScaling(),
+        MXFP8BlockScaling(),
+        NVFP4BlockScaling(),
     ]
     return [
         recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0]
     ]


-def _format2dtypes(format_: recipe.Format):
+def _format2dtypes(format_: Format):
     """Convert recipe.Format.dtype to corresponding JAX dtypes.

     Args:
@@ -219,13 +226,13 @@ def _format2dtypes(format_: recipe.Format):
     Returns:
         A tuple of (forward_dtype, backward_dtype) for the given format
     """
-    if format_ == recipe.Format.E4M3:
+    if format_ == Format.E4M3:
         return jnp.float8_e4m3fn, jnp.float8_e4m3fn
-    if format_ == recipe.Format.E5M2:
+    if format_ == Format.E5M2:
         return jnp.float8_e5m2, jnp.float8_e5m2
-    if format_ == recipe.Format.HYBRID:
+    if format_ == Format.HYBRID:
         return jnp.float8_e4m3fn, jnp.float8_e5m2
-    if format_ == recipe.Format.E2M1:
+    if format_ == Format.E2M1:
         return jnp.float4_e2m1fn, jnp.float4_e2m1fn
     return jnp.bfloat16, jnp.bfloat16
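The mapping that `_format2dtypes` encodes is easiest to read as a table. A standalone sketch, illustrative only: `jnp.float4_e2m1fn` requires a recent JAX/ml_dtypes build, and `Format.E2M1` is assumed to exist exactly as this commit uses it.

    import jax.numpy as jnp
    from transformer_engine.common.recipe import Format

    # (forward dtype, backward dtype) per recipe format; HYBRID keeps E4M3 for
    # activations/weights in the forward pass and E5M2 for gradients in backward.
    FORMAT_TO_DTYPES = {
        Format.E4M3: (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
        Format.E5M2: (jnp.float8_e5m2, jnp.float8_e5m2),
        Format.HYBRID: (jnp.float8_e4m3fn, jnp.float8_e5m2),
        Format.E2M1: (jnp.float4_e2m1fn, jnp.float4_e2m1fn),  # NVFP4 element type
    }

    fwd_dtype, bwd_dtype = FORMAT_TO_DTYPES[Format.HYBRID]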
@@ -289,7 +296,7 @@ class BaseQuantizeConfig(ABC):
     AMAX_HISTORY_LEN: int = 1024
     AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
        """Initialize the quantization configuration from a given recipe.

        Args:
@@ -359,7 +366,7 @@ class BaseQuantizeConfig(ABC):
 class NoOpQuantizeConfig(BaseQuantizeConfig):
     """Configuration class higher-precision non-quantized operation."""

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
         """Initialize no-op configuration."""
         raise NotImplementedError(
             "NoOpQuantizeConfig cannot be initialize from a recipe as it represents"
@@ -399,7 +406,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
     FP8 quantization mode.
     """

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
         """Initialize delayed scaling FP8 configuration.

         Args:
@@ -477,7 +484,7 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
     FP8 quantization mode.
     """

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
         """Initialize current scaling FP8 configuration.

         Args:
@@ -519,7 +526,7 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
     FP8 quantization mode.
     """

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
         """Initialize block scaling FP8 configuration.

         Args:
@@ -560,7 +567,7 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
     This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
     """

-    def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
+    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
         """Initialize block scaling FP8 configuration.

         Args:
@@ -622,12 +629,12 @@ _QUANTIZE_CONFIG = NoOpQuantizeConfig()


 def get_quantize_config():
-    """Global instance of BaseQuantizeConfig set by fp8_autocast context."""
+    """Global instance of BaseQuantizeConfig set by autocast context."""
     return _QUANTIZE_CONFIG


 def get_quantize_config_class(
-    fp8_recipe: recipe.Recipe,
+    fp8_recipe: Recipe,
 ) -> Type[BaseQuantizeConfig]:
     """Get the quantization configuration class based on the FP8 recipe.
@@ -636,18 +643,18 @@ def get_quantize_config_class(
     Returns:
         The quantization config class corresponding to the given recipe.
     """
-    if isinstance(fp8_recipe, recipe.DelayedScaling):
+    if isinstance(fp8_recipe, DelayedScaling):
         return DelayedScalingQuantizeConfig
-    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+    if isinstance(fp8_recipe, MXFP8BlockScaling):
         return BlockScalingQuantizeConfig
-    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+    if isinstance(fp8_recipe, Float8CurrentScaling):
         return CurrentScalingQuantizeConfig
-    if isinstance(fp8_recipe, recipe.NVFP4BlockScaling):
+    if isinstance(fp8_recipe, NVFP4BlockScaling):
         return NVFP4ScalingQuantizeConfig
     raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")


-def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe):
+def get_quantize_config_with_recipe(fp8_recipe: Recipe):
     """Get the quantization configuration object based on the FP8 recipe."""
     config = get_quantize_config_class(fp8_recipe)()
     config.initialize_from_recipe(fp8_recipe)
@@ -655,14 +662,14 @@ def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe):

 @contextmanager
-def fp8_autocast(
+def autocast(
     enabled: bool = False,
-    fp8_recipe: Optional[recipe.Recipe] = None,
+    recipe: Optional[Recipe] = None,
     mesh_resource: Optional[MeshResource] = None,
 ) -> None:
-    r"""Context manager for FP8 automatic mixed precision.
+    r"""Context manager for FP8 or FP4 usage.

-    This context manager enables FP8 quantization for the duration of its context.
+    This context manager enables quantization for the duration of its context.

     .. code-block:: python

         mesh_shape = (4, 2)
@@ -673,7 +680,7 @@ def fp8_autocast(
         with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
             mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

-            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
+            with autocast(enabled=True, mesh_resource=mesh_resource):
                 rules = extend_logical_axis_rules(tuple())
                 transformer = TransformerLayer()
@@ -690,15 +697,15 @@ def fp8_autocast(
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
-    fp8_recipe: recipe.DelayedScaling, default = None
-        Recipe used for FP8 training.
+    recipe: recipe.DelayedScaling, default = None
+        recipe used for low precision quantization.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.
    """
-    if fp8_recipe is None:
-        fp8_recipe = recipe.DelayedScaling()
+    if recipe is None:
+        recipe = DelayedScaling()

     global _QUANTIZE_CONFIG
@@ -709,15 +716,45 @@ def fp8_autocast(
     try:
         with global_shard_guard(mesh_resource):
             if enabled:
-                _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)()
+                _QUANTIZE_CONFIG = get_quantize_config_class(recipe)()
                 is_supported, reason = _QUANTIZE_CONFIG.is_supported()
                 assert is_supported, reason
-                _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe)
+                _QUANTIZE_CONFIG.initialize_from_recipe(recipe)
             yield
     finally:
         _QUANTIZE_CONFIG = old_quantize_config


+@contextmanager
+def fp8_autocast(
+    enabled: bool = False,
+    fp8_recipe: Optional[Recipe] = None,
+    mesh_resource: Optional[MeshResource] = None,
+) -> None:
+    """
+    .. warning::
+        fp8_autocast is deprecated and will be removed in a future release.
+        Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.
+    """
+    warnings.warn(
+        "fp8_autocast is deprecated and will be removed in a future release. "
+        "Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.",
+        category=DeprecationWarning,
+        stacklevel=2,
+    )
+    # Call new implementation.
+    with autocast(
+        enabled=enabled,
+        recipe=fp8_recipe,
+        mesh_resource=mesh_resource,
+    ):
+        yield


 def update_collections(new: Collection, original: Collection) -> Collection:
     r"""Update collections with new values while preserving original structure.
......
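Net effect of this file: the JAX package gains a recipe-agnostic `autocast` and keeps `fp8_autocast` as a deprecation shim. A minimal usage sketch, assuming both names are re-exported at the package root the way `fp8_autocast` was previously:

    import transformer_engine.jax as te
    from transformer_engine.common.recipe import MXFP8BlockScaling

    # New spelling: any supported recipe is passed through `recipe=`.
    with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
        ...  # build and run TE Flax modules here

    # Old spelling still works via the shim above, but emits a DeprecationWarning.
    with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
        ...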
@@ -46,8 +46,18 @@ from transformer_engine.pytorch.permutation import (
     moe_sort_chunks_by_index,
     moe_sort_chunks_by_index_with_probs,
 )
-from transformer_engine.pytorch.fp8 import fp8_autocast
-from transformer_engine.pytorch.fp8 import fp8_model_init
+from transformer_engine.pytorch.quantization import fp8_autocast
+from transformer_engine.pytorch.quantization import fp8_model_init
+from transformer_engine.pytorch.quantization import autocast
+from transformer_engine.pytorch.quantization import quantized_model_init
+from transformer_engine.pytorch.quantization import is_fp8_available
+from transformer_engine.pytorch.quantization import is_mxfp8_available
+from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available
+from transformer_engine.pytorch.quantization import is_nvfp4_available
+from transformer_engine.pytorch.quantization import get_default_recipe
+from transformer_engine.pytorch.utils import get_cudnn_version
+from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import is_bf16_available
 from transformer_engine.pytorch.graph import make_graphed_callables
 from transformer_engine.pytorch.distributed import checkpoint
 from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
@@ -61,14 +71,17 @@ from transformer_engine.pytorch.tensor import Float8Quantizer
 from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
 from transformer_engine.pytorch.tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor import Float8BlockQuantizer
+from transformer_engine.pytorch.tensor import NVFP4Quantizer
 from transformer_engine.pytorch.tensor import QuantizedTensorStorage
 from transformer_engine.pytorch.tensor import Float8TensorStorage
 from transformer_engine.pytorch.tensor import MXFP8TensorStorage
 from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage
+from transformer_engine.pytorch.tensor import NVFP4TensorStorage
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor import Float8Tensor
 from transformer_engine.pytorch.tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
+from transformer_engine.pytorch.tensor import NVFP4Tensor
 from transformer_engine.pytorch.tensor import prepare_for_saving
 from transformer_engine.pytorch.tensor import restore_from_saved
......
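The test imports above double as a map of the new public surface in `transformer_engine.pytorch.quantization`. A migration sketch using only names imported in this hunk; the keyword arguments of `quantized_model_init` are assumed to mirror `fp8_model_init`, and `is_nvfp4_available` is assumed here to return a plain bool, so check the real signatures before relying on them:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.quantization import (
        autocast,
        quantized_model_init,
        is_nvfp4_available,
        get_default_recipe,
    )
    from transformer_engine.common.recipe import NVFP4BlockScaling

    # Prefer NVFP4 where the hardware supports it, else fall back.
    recipe = NVFP4BlockScaling() if is_nvfp4_available() else get_default_recipe()

    # quantized_model_init generalizes fp8_model_init: parameters are created
    # directly in the recipe's quantized representation.
    with quantized_model_init(enabled=True, recipe=recipe):
        model = te.Linear(768, 768)

    with autocast(enabled=True, recipe=recipe):
        out = model(torch.randn(16, 768, device="cuda", dtype=torch.bfloat16))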
@@ -42,7 +42,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     META_O,
     META_QKV,
 )
-from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager
 from transformer_engine.pytorch.distributed import get_distributed_world_size
 from transformer_engine.pytorch.jit import no_torch_dynamo
 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
@@ -1074,7 +1074,7 @@ class FusedAttnFunc(torch.autograd.Function):
         nvtx_label = "transformer_engine.FusedAttnFunc.forward"
         nvtx_range_push(f"{nvtx_label}")
-        # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
+        # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
         # may be different from fp8_meta["recipe"]
         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
         if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
......
@@ -19,7 +19,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_bwd,
     FusedAttnBackend,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
 from transformer_engine.pytorch.jit import jit_fuser
@@ -1164,7 +1164,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         is_input_fp8 = isinstance(q, Float8Tensor)
         is_output_fp8 = fp8_output
         is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
-        # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
+        # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
         # may be different from fp8_meta["recipe"]
         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
         if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
@@ -3151,7 +3151,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         is_input_fp8 = isinstance(q, Float8Tensor)
         is_output_fp8 = fp8_output
         is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
-        # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
+        # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
         # may be different from fp8_meta["recipe"]
         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
         if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
......
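The `NVTE_FP8_DPA_BWD` flag read in these hunks controls whether the context-parallel attention backward also runs in FP8; it defaults to on. A one-liner to keep the backward pass in higher precision:

    import os

    # Read at runtime by the attention autograd functions above; "0" keeps the
    # backward pass out of FP8 while the forward still quantizes.
    os.environ["NVTE_FP8_DPA_BWD"] = "0"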
@@ -21,7 +21,7 @@ from transformer_engine.common.recipe import (
     Float8CurrentScaling,
 )
 from transformer_engine.pytorch.utils import get_cudnn_version
-from transformer_engine.pytorch.fp8 import (
+from transformer_engine.pytorch.quantization import (
     get_fp8_te_dtype,
     FP8GlobalStateManager,
     RecipeState,
@@ -91,26 +91,26 @@ _alibi_cache = {
 This feature is **experimental** and subject to change.

 Some models may use different FP8 recipes for their linear layers and attention layers. To support this,
-users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer,
-or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention
+users can either use multiple, nested autocast() contexts to assign a distinct recipe for each layer,
+or use a single autocast() for the non-attention layers and configure the recipe for the attention
 layers as follows.

 +-------------------+-----------+-----------------------------------------------------------------------------------+
 | Linear            | Attention | Configuration                                                                     |
 +===================+===========+===================================================================================+
-| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast();                                     |
+| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast();                                         |
 |                   |           | export NVTE_DPA_FP8_RECIPE="F16"                                                  |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| FP8DS             | FP8DS     | Pass FP8DS to fp8_autocast();                                                     |
+| FP8DS             | FP8DS     | Pass FP8DS to autocast();                                                         |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| FP8CS             | FP8DS     | Pass FP8CS to fp8_autocast();                                                     |
+| FP8CS             | FP8DS     | Pass FP8CS to autocast();                                                         |
 |                   |           | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
 |                   |           | export NVTE_DPA_FP8_RECIPE="DelayedScaling"    # switch to DS                     |
 |                   |           | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent"  # or "max"                         |
 |                   |           | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1           # or any other integer             |
 |                   |           | export NVTE_DPA_FP8DS_REDUCE_AMAX=1            # or 0                             |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| NVFP4             | FP8DS     | Pass NVFP4 to fp8_autocast();                                                     |
+| NVFP4             | FP8DS     | Pass NVFP4 to autocast();                                                         |
 |                   |           | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4;             |
 |                   |           | export NVTE_DPA_FP8_RECIPE="DelayedScaling"    # switch to DS                     |
 |                   |           | export NVTE_DPA_FP8_FORMAT="HYBRID"            # or "E4M3", "E5M2"                |
@@ -118,19 +118,19 @@ layers as follows.
 |                   |           | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1           # or any other integer             |
 |                   |           | export NVTE_DPA_FP8DS_REDUCE_AMAX=1            # or 0                             |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| FP8DS             | FP8CS     | Pass FP8DS to fp8_autocast();                                                     |
+| FP8DS             | FP8CS     | Pass FP8DS to autocast();                                                         |
 |                   |           | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
 |                   |           | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS;         |
 |                   |           | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling"  # switch to CS                 |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| FP8CS             | FP8CS     | Pass FP8CS to fp8_autocast();                                                     |
+| FP8CS             | FP8CS     | Pass FP8CS to autocast();                                                         |
 |                   |           | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
 |                   |           | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and:    |
 |                   |           | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent"  # or "max"                         |
 |                   |           | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1           # or any other integer             |
 |                   |           | export NVTE_DPA_FP8DS_REDUCE_AMAX=1            # or 0                             |
 +-------------------+-----------+-----------------------------------------------------------------------------------+
-| NVFP4             | FP8CS     | Pass NVFP4 to fp8_autocast();                                                     |
+| NVFP4             | FP8CS     | Pass NVFP4 to autocast();                                                         |
 |                   |           | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
 |                   |           | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and:            |
 |                   |           | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling"  # switch to CS                 |
@@ -544,7 +544,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         """
         _original_recipe = self.fp8_meta.get("recipe", None)
-        # global recipe set in fp8_autocast()
+        # global recipe set in autocast()
         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
         if fp8_recipe.custom():
             return
@@ -560,7 +560,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             fp8_recipe_dpa = fp8_recipe
             fp8_recipes = fp8_recipe
             if _dpa_fp8_recipe == "F16":
-                # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False
+                # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False
                 fp8_recipe.fp8_dpa = False
                 fp8_recipe.fp8_mha = False
             elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
......
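End to end, the first table row (quantized GEMMs, higher-precision attention) might look as follows. `NVTE_DPA_FP8_RECIPE` is read at import time (see the `_dpa_fp8_recipe = os.getenv(...)` hunk further down), so it must be set before Transformer Engine is imported; layer sizes are illustrative:

    import os

    os.environ["NVTE_DPA_FP8_RECIPE"] = "F16"  # attention stays in F16/BF16

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.quantization import autocast
    from transformer_engine.common.recipe import Float8CurrentScaling

    layer = te.TransformerLayer(1024, 4096, 16).cuda()
    x = torch.randn(8, 2, 1024, device="cuda", dtype=torch.bfloat16)

    # Linear layers quantize with FP8 current scaling; attention does not.
    with autocast(enabled=True, recipe=Float8CurrentScaling()):
        y = layer(x)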
@@ -40,7 +40,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,
     Float8CurrentScalingQuantizer,
 )
-from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
+from transformer_engine.pytorch.quantization import get_fp8_te_dtype
 from transformer_engine.pytorch.constants import TE_DType
@@ -222,7 +222,7 @@ class AttentionParams:
     is_training: bool, default = `True`
         Whether in training mode (`True`) or inference mode (`False`)
     fp8: bool, default = `False`
-        Whether `DotProductAttention` is in an `fp8_autocast` region.
+        Whether `DotProductAttention` is in an `autocast` region.
     fp8_meta: Optional[Dict[str Any]], default = `None`
         The FP8 metadata tensor of `DotProductAttention`.
     inference_params: Optional[InferenceParams], default = `None`
......
@@ -9,7 +9,7 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
@@ -33,7 +33,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProduc
 from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb

-# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast().
+# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
 # Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
 # and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
 _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
......
@@ -36,7 +36,7 @@ from .utils import (
     needs_quantized_gemm,
 )
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager, fp8_autocast
+from .quantization import FP8GlobalStateManager, autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.nvfp4_tensor import NVFP4Quantizer
@@ -419,8 +419,8 @@ class _CheckpointFunction(torch.autograd.Function):
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
-        ), fp8_autocast(
-            enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
+        ), autocast(
+            enabled=ctx.fp8, recipe=ctx.fp8_recipe
         ):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
@@ -754,8 +754,8 @@ def checkpoint(
     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
-            enabled=fp8, fp8_recipe=fp8_recipe
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, autocast(
+            enabled=fp8, recipe=fp8_recipe
         ):
             function(*args, **kwargs)
@@ -1969,7 +1969,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     if hasattr(fsdp_root, "primary_weights_in_fp8"):
         assert not fsdp_root.primary_weights_in_fp8, (
             "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-            "Please initialize your model without the te.fp8_model_init(...) context."
+            "Please initialize your model without the te.quantized_model_init(...) context."
         )
     root_state = _get_module_fsdp_state(fsdp_root)
     assert root_state is not None, "Root module does not have a valid _FSDPState."
@@ -1982,7 +1982,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
         if hasattr(fsdp_module.module, "primary_weights_in_fp8"):
             assert not fsdp_module.module.primary_weights_in_fp8, (
                 "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-                "Please initialize your model without the te.fp8_model_init(...) context."
+                "Please initialize your model without the te.quantized_model_init(...) context."
             )
         setattr(fsdp_module.module, "fsdp_group", state.process_group)
......
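For completeness: the checkpointing path above re-enters `autocast` with the captured recipe during recomputation, so user code only wraps the outer forward. A sketch under assumed shapes (chosen to satisfy FP8 GEMM divisibility; the module choice is illustrative):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.distributed import checkpoint
    from transformer_engine.pytorch.quantization import autocast
    from transformer_engine.common.recipe import DelayedScaling

    layer = te.LayerNormMLP(1024, 4096).cuda()
    x = torch.randn(16, 4, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Recompute-in-backward re-applies autocast(enabled=..., recipe=...) itself.
    with autocast(enabled=True, recipe=DelayedScaling()):
        y = checkpoint(layer, x)
    y.sum().backward()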