Unverified Commit 85a91997 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and api
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tets
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make Jax change as well
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* CG function, and small test fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better api
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"""Unit tests for context parallel utils.""" """Unit tests for context parallel utils."""
import torch import torch
import unittest import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank, get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp, pad_thd_sequences_for_cp,
......
...@@ -14,20 +14,22 @@ import pytest ...@@ -14,20 +14,22 @@ import pytest
import torch import torch
from torch.distributions import Exponential from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables from transformer_engine.pytorch import (
from transformer_engine.common import recipe make_graphed_callables,
from transformer_engine.pytorch import fp8_autocast, fp8_model_init autocast,
from transformer_engine.pytorch.transformer import ( quantized_model_init,
TransformerLayer, TransformerLayer,
DotProductAttention,
InferenceParams,
is_bf16_available,
) )
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils, FlashAttentionUtils as fa_utils,
) )
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible,
) )
_current_file = pathlib.Path(__file__).resolve() _current_file = pathlib.Path(__file__).resolve()
...@@ -42,7 +44,7 @@ from utils import ( ...@@ -42,7 +44,7 @@ from utils import (
reset_rng_states() reset_rng_states()
param_types = [torch.float16] param_types = [torch.float16]
if is_bf16_compatible(): if is_bf16_available():
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
model_configs_infer = { model_configs_infer = {
...@@ -238,7 +240,7 @@ def get_model( ...@@ -238,7 +240,7 @@ def get_model(
if module == "TransformerLayer": if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [ model = [
TransformerLayer( TransformerLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -261,7 +263,7 @@ def get_model( ...@@ -261,7 +263,7 @@ def get_model(
for layer_number in range(1, num_layers + 1) for layer_number in range(1, num_layers + 1)
] ]
if module == "DotProductAttention": if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [ model = [
DotProductAttention( DotProductAttention(
kv_channels=config.head_dim_qk, kv_channels=config.head_dim_qk,
...@@ -559,9 +561,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g ...@@ -559,9 +561,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
model[i], model[i],
sample_args, sample_args,
num_warmup_iters=10, num_warmup_iters=10,
fp8_enabled=is_fp8, enabled=is_fp8,
sample_kwargs=sample_kwargs, sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe, recipe=fp8_recipe,
) )
for i in range(num_layers) for i in range(num_layers)
] ]
...@@ -654,7 +656,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g ...@@ -654,7 +656,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
if inference_params.is_paged: if inference_params.is_paged:
inference_params.cache_manager.print_cache() inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): with autocast(enabled=is_fp8, recipe=fp8_recipe):
for m in model: for m in model:
incremental_output = m( incremental_output = m(
*incremental_output, *incremental_output,
......
...@@ -16,7 +16,7 @@ import transformer_engine ...@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch import is_fp8_available
from test_numerics import ( from test_numerics import (
_emulate_linear, _emulate_linear,
...@@ -45,7 +45,7 @@ FEATURE_DIRS = None ...@@ -45,7 +45,7 @@ FEATURE_DIRS = None
all_boolean = [True, False] all_boolean = [True, False]
TEST_NR = 0 TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available() fp8_available = is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
...@@ -117,7 +117,7 @@ class AllGather(torch.autograd.Function): ...@@ -117,7 +117,7 @@ class AllGather(torch.autograd.Function):
def _run_forward_backward(x, model, parallel_mode=None, group=None): def _run_forward_backward(x, model, parallel_mode=None, group=None):
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x) y = model(x)
y.requires_grad_(True) y.requires_grad_(True)
...@@ -413,13 +413,13 @@ def test_log_expert_parallel(**kwargs): ...@@ -413,13 +413,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel ) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1") model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2") model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y1 = model(x) y1 = model(x)
y2 = model1(x) y2 = model1(x)
y = y1 + y2 y = y1 + y2
y.sum().backward() y.sum().backward()
debug_api.step() debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y = model(x) y = model(x)
if WORLD_RANK != 0: if WORLD_RANK != 0:
y = y + model1(x) y = y + model1(x)
...@@ -532,7 +532,7 @@ def test_per_tensor_scaling( ...@@ -532,7 +532,7 @@ def test_per_tensor_scaling(
LOSS_MULTIPLIER = 100 LOSS_MULTIPLIER = 100
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x) y = model(x)
model.zero_grad() model.zero_grad()
if parallel_mode == "column": if parallel_mode == "column":
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
import torch import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
import pathlib, os import pathlib
from nvdlfw_inspect.config_manager import ConfigManager from nvdlfw_inspect.config_manager import ConfigManager
......
...@@ -8,18 +8,22 @@ import transformer_engine.pytorch as te ...@@ -8,18 +8,22 @@ import transformer_engine.pytorch as te
import torch import torch
import tempfile import tempfile
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest import pytest
import contextlib import contextlib
import os import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
FP8GlobalStateManager.is_fp8_block_scaling_available() return_reason=True
) )
LOG_QUANTIZED_CONFIG_BASE = """ LOG_QUANTIZED_CONFIG_BASE = """
...@@ -128,7 +132,7 @@ def test_sanity(feature_dirs): ...@@ -128,7 +132,7 @@ def test_sanity(feature_dirs):
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10): for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): with te.autocast(recipe=recipe.DelayedScaling()):
output = model(inp) output = model(inp)
loss = output.sum() loss = output.sum()
loss.backward() loss.backward()
...@@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): ...@@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
for i in range(20): for i in range(20):
x = torch.randn(4, 128, 128).cuda() x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True): with te.autocast(enabled=True):
y = model(x) y = model(x)
y.sum().backward() y.sum().backward()
debug_api.step() debug_api.step()
......
...@@ -17,19 +17,19 @@ import transformer_engine.debug ...@@ -17,19 +17,19 @@ import transformer_engine.debug
import transformer_engine.pytorch as tepytorch import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import _default_sf_compute from transformer_engine.pytorch.quantization import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import ( from transformer_engine.pytorch import (
Float8Quantizer, Float8Quantizer,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
is_fp8_available,
) )
from transformer_engine.pytorch.module.base import ( from transformer_engine.pytorch.module.base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
all_boolean = [True, False] all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID FP8_FORMAT = Format.HYBRID
...@@ -250,7 +250,7 @@ def _init_model(weight): ...@@ -250,7 +250,7 @@ def _init_model(weight):
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True): def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE): with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch) y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward() (y.sum() * loss_scale).backward()
debug_api.step() debug_api.step()
...@@ -547,7 +547,7 @@ def run_per_tensor_scaling( ...@@ -547,7 +547,7 @@ def run_per_tensor_scaling(
LOSS_MULTIPLIER = 100 LOSS_MULTIPLIER = 100
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=True) y = model(x, is_first_microbatch=True)
model.zero_grad() model.zero_grad()
y.retain_grad() y.retain_grad()
......
...@@ -7,11 +7,10 @@ import torch ...@@ -7,11 +7,10 @@ import torch
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import create_config_file from test_numerics import create_config_file
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
B, S, H, D = 64, 64, 64, 64 B, S, H, D = 64, 64, 64, 64
...@@ -68,7 +67,7 @@ def _get_model(model_key): ...@@ -68,7 +67,7 @@ def _get_model(model_key):
def _run_forward_backward(model, fp8): def _run_forward_backward(model, fp8):
for _ in range(3): for _ in range(3):
inp = torch.randn((S, B, H)).cuda() inp = torch.randn((S, B, H)).cuda()
with te.fp8_autocast(enabled=fp8): with te.autocast(enabled=fp8):
out = model(inp) out = model(inp)
out.sum().backward() out.sum().backward()
debug_api.step() debug_api.step()
......
...@@ -21,13 +21,13 @@ from transformer_engine.common.recipe import ( ...@@ -21,13 +21,13 @@ from transformer_engine.common.recipe import (
Recipe, Recipe,
) )
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 from transformer_engine.pytorch import (
from transformer_engine.pytorch.tensor.float8_tensor import ( QuantizedTensor,
Float8Tensor, Float8Tensor,
Float8CurrentScalingQuantizer, Float8BlockwiseQTensor,
) )
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor): def _get_raw_data(quantized_tensor):
...@@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
} }
# Create model with FP8 weights # Create model with FP8 weights
with te.fp8.fp8_model_init( with te.quantized_model_init(
enabled=quantization is not None, enabled=quantization is not None,
recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True, preserve_high_precision_init_val=True,
...@@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different. # Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank] x = inputs[rank]
with te.fp8.fp8_autocast( with te.autocast(
enabled=quantization is not None, enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
fp8_group=mock_group, amax_reduction_group=mock_group,
): ):
y_fp8 = model_fp8(x) y_fp8 = model_fp8(x)
with te.fp8_autocast( with te.autocast(
enabled=quantization is not None, enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
fp8_group=mock_group, amax_reduction_group=mock_group,
): ):
y = model(x) y = model(x)
...@@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
# Create model with FP8 weights # Create model with FP8 weights
with te.fp8.fp8_model_init( with te.quantized_model_init(
enabled=quantization is not None, enabled=quantization is not None,
recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True, preserve_high_precision_init_val=True,
...@@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different. # Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank] x = inputs[rank]
with te.fp8.fp8_autocast( with te.autocast(
enabled=quantization is not None, enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
fp8_group=mock_group, amax_reduction_group=mock_group,
): ):
y_fp8 = model_fp8(x) y_fp8 = model_fp8(x)
with te.fp8_autocast( with te.autocast(
enabled=quantization is not None, enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization), recipe=quantization_recipe(quantization),
fp8_group=mock_group, amax_reduction_group=mock_group,
): ):
y = model(x) y = model(x)
......
...@@ -110,9 +110,9 @@ def _train(args): ...@@ -110,9 +110,9 @@ def _train(args):
build_model_context = nullcontext build_model_context = nullcontext
build_model_context_args = {} build_model_context_args = {}
from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch import quantized_model_init
build_model_context = fp8_model_init build_model_context = quantized_model_init
build_model_context_args["enabled"] = True build_model_context_args["enabled"] = True
# Build the model with the specified context # Build the model with the specified context
......
...@@ -18,9 +18,12 @@ import torch.distributed as dist ...@@ -18,9 +18,12 @@ import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
Float8Tensor,
Float8Quantizer,
MXFP8Quantizer,
)
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import ( from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes, get_cublas_workspace_size_bytes,
...@@ -171,12 +174,12 @@ def _parse_args(argv=None, namespace=None): ...@@ -171,12 +174,12 @@ def _parse_args(argv=None, namespace=None):
opts.p2p = True opts.p2p = True
if opts.atomic: if opts.atomic:
if not te.fp8.check_fp8_support(): if not te.is_fp8_available():
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8." assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8" opts.quantization = "fp8"
if opts.fp8_output: if opts.fp8_output:
assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute." assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts return opts
......
...@@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None): ...@@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None):
) )
parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument( parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
) )
parser.add_argument( parser.add_argument(
"--quantization", "--quantization",
...@@ -438,7 +438,7 @@ def _train(opts): ...@@ -438,7 +438,7 @@ def _train(opts):
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
) )
with te.fp8_model_init(enabled=opts.fp8_init): with te.quantized_model_init(enabled=opts.fp8_init):
test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs) test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
dist_print("Initialized test model...", debug=True) dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0: if WORLD_RANK == 0:
...@@ -450,7 +450,7 @@ def _train(opts): ...@@ -450,7 +450,7 @@ def _train(opts):
ref_args, ref_kwargs, _ = _get_layer_args( ref_args, ref_kwargs, _ = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
) )
with te.fp8_model_init(enabled=opts.fp8_init): with te.quantized_model_init(enabled=opts.fp8_init):
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs) ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True) dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
...@@ -473,7 +473,9 @@ def _train(opts): ...@@ -473,7 +473,9 @@ def _train(opts):
layer_contexts = [ layer_contexts = [
( (
partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) partial(
te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world
)
if opts.num_layers_at_start_in_bf16 <= i if opts.num_layers_at_start_in_bf16 <= i
and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
else nullcontext else nullcontext
......
...@@ -26,8 +26,7 @@ from transformer_engine.common.recipe import ( ...@@ -26,8 +26,7 @@ from transformer_engine.common.recipe import (
Recipe, Recipe,
QParams, QParams,
) )
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.distributed import gather_along_first_dim from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors from run_layer_with_overlap import _compare_tensors
...@@ -75,7 +74,7 @@ def quantization_recipe() -> Recipe: ...@@ -75,7 +74,7 @@ def quantization_recipe() -> Recipe:
return Float8BlockScaling() return Float8BlockScaling()
if QUANTIZATION == "nvfp4": if QUANTIZATION == "nvfp4":
return nvfp4_vanilla() return nvfp4_vanilla()
return te.fp8.get_default_fp8_recipe() return te.quantization.get_default_fp8_recipe()
def main(argv=None, namespace=None): def main(argv=None, namespace=None):
...@@ -316,15 +315,15 @@ def _apply_models( ...@@ -316,15 +315,15 @@ def _apply_models(
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node.requires_grad_() input_single_node.requires_grad_()
input_distributed.requires_grad_() input_distributed.requires_grad_()
with te.fp8_autocast( with te.autocast(
enabled=QUANTIZATION is not None, enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(), recipe=quantization_recipe(),
): ):
output_single_node = model_single_node(input_single_node, **kwargs) output_single_node = model_single_node(input_single_node, **kwargs)
with te.fp8_autocast( with te.autocast(
enabled=QUANTIZATION is not None, enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(), recipe=quantization_recipe(),
fp8_group=NCCL_WORLD, amax_reduction_group=NCCL_WORLD,
): ):
output_distributed = model_distributed(input_distributed, **kwargs) output_distributed = model_distributed(input_distributed, **kwargs)
return output_single_node, output_distributed return output_single_node, output_distributed
......
...@@ -9,21 +9,18 @@ import datetime ...@@ -9,21 +9,18 @@ import datetime
import os import os
import sys import sys
from functools import wraps from functools import wraps
import math
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import torch import torch
from torch import nn from torch import nn
import torch.distributed as dist import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
NVFP4BlockScaling, NVFP4BlockScaling,
Format,
Recipe, Recipe,
QParams, QParams,
CustomRecipe, CustomRecipe,
) )
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4 from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.experimental import utils
...@@ -506,7 +503,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ...@@ -506,7 +503,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
) )
# run the recipe under test # run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe): with te.autocast(enabled=True, recipe=recipe):
y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear( y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
x, x,
w, w,
...@@ -524,7 +521,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ...@@ -524,7 +521,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# run the reference # run the reference
reference_recipe = quantization_reference_recipe() reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
x, x,
w, w,
...@@ -700,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs ...@@ -700,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
) )
# run the recipe under test # run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe): with te.autocast(enabled=True, recipe=recipe):
y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear( y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
x, x,
w, w,
...@@ -717,7 +714,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs ...@@ -717,7 +714,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# run the reference # run the reference
reference_recipe = quantization_reference_recipe() reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
TestDistributedLayerNormLinearBase.run_layernorm_linear( TestDistributedLayerNormLinearBase.run_layernorm_linear(
x, x,
......
...@@ -8,15 +8,15 @@ from pathlib import Path ...@@ -8,15 +8,15 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
FP8GlobalStateManager.is_fp8_block_scaling_available() return_reason=True
) )
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
......
...@@ -9,13 +9,12 @@ import pytest ...@@ -9,13 +9,12 @@ import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
RNG_SEED: int = 42 RNG_SEED: int = 42
SEQ_LENGTH: int = 1024 SEQ_LENGTH: int = 1024
......
...@@ -20,16 +20,15 @@ import torch ...@@ -20,16 +20,15 @@ import torch
import transformer_engine import transformer_engine
import transformer_engine.common.recipe import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch import (
from transformer_engine.pytorch.tensor import QuantizedTensor QuantizedTensor,
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer, Float8Quantizer,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
# Import utility functions # Import utility functions
...@@ -39,9 +38,9 @@ from utils import dtype_tols, make_recipe, quantization_tols ...@@ -39,9 +38,9 @@ from utils import dtype_tols, make_recipe, quantization_tols
# Check what quantization schemes are supported # Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_mxfp8_available() nvfp4_available, reason_for_no_nvfp4 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None] quantization_list: list[Optional[str]] = [None]
if fp8_available: if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
...@@ -427,7 +426,7 @@ def _test_basic_linear( ...@@ -427,7 +426,7 @@ def _test_basic_linear(
# Implementation with fusible operation # Implementation with fusible operation
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear( op = te_ops.BasicLinear(
in_features, in_features,
out_features, out_features,
...@@ -440,7 +439,7 @@ def _test_basic_linear( ...@@ -440,7 +439,7 @@ def _test_basic_linear(
with torch.no_grad(): with torch.no_grad():
op.weight.copy_(w_test) op.weight.copy_(w_test)
del w_test del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test) y_test = op(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -593,7 +592,7 @@ def _test_linear( ...@@ -593,7 +592,7 @@ def _test_linear(
# Implementation with fusible operation # Implementation with fusible operation
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential( model = te_ops.Sequential(
te_ops.Linear( te_ops.Linear(
in_features, in_features,
...@@ -612,7 +611,7 @@ def _test_linear( ...@@ -612,7 +611,7 @@ def _test_linear(
model[0].bias.copy_(b_test) model[0].bias.copy_(b_test)
del w_test del w_test
del b_test del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test) y_test = model(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -759,7 +758,7 @@ def _test_mlp( ...@@ -759,7 +758,7 @@ def _test_mlp(
# Implementation with fusible operation # Implementation with fusible operation
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential( model = te_ops.Sequential(
te_ops.GELU(), te_ops.GELU(),
te_ops.Linear( te_ops.Linear(
...@@ -795,7 +794,7 @@ def _test_mlp( ...@@ -795,7 +794,7 @@ def _test_mlp(
# Warmup steps # Warmup steps
for _ in range(3): for _ in range(3):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test) y_test = model(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
x_test.grad = None x_test.grad = None
...@@ -806,7 +805,7 @@ def _test_mlp( ...@@ -806,7 +805,7 @@ def _test_mlp(
model[3].bias.grad = None model[3].bias.grad = None
# Forward and backward step # Forward and backward step
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test) y_test = model(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -944,7 +943,7 @@ def _test_fp8_scale_update( ...@@ -944,7 +943,7 @@ def _test_fp8_scale_update(
amax_history_len=amax_history_len, amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo, amax_compute_algo=amax_compute_algo,
) )
with te.fp8_autocast(fp8_recipe=recipe): with te.autocast(recipe=recipe):
y_test = op(x_test) y_test = op(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -1004,7 +1003,7 @@ def run_parallel_tests() -> None: ...@@ -1004,7 +1003,7 @@ def run_parallel_tests() -> None:
if rank == 0: if rank == 0:
print(f"Running _test_linear with {config=}") print(f"Running _test_linear with {config=}")
quantization, tensor_parallel_mode, sequence_parallel = config quantization, tensor_parallel_mode, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_linear( _test_linear(
bias=True, # bias=False is tested in _test_basic_linear bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype, dtype=dtype,
...@@ -1018,7 +1017,7 @@ def run_parallel_tests() -> None: ...@@ -1018,7 +1017,7 @@ def run_parallel_tests() -> None:
if rank == 0: if rank == 0:
print(f"Running _test_mlp with {config=}") print(f"Running _test_mlp with {config=}")
quantization, sequence_parallel = config quantization, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_mlp( _test_mlp(
bias=True, # bias=False is tested in _test_basic_linear bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype, dtype=dtype,
......
...@@ -16,23 +16,23 @@ import sys ...@@ -16,23 +16,23 @@ import sys
import pytest import pytest
import torch import torch
from typing import Optional, Iterable
import transformer_engine import transformer_engine
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
) )
from transformer_engine.pytorch.tensor.float8_tensor import ( from transformer_engine.pytorch import (
Float8Quantizer, Float8Quantizer,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
MXFP8Quantizer,
QuantizedTensor,
Float8Tensor,
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
_current_file = pathlib.Path(__file__).resolve() _current_file = pathlib.Path(__file__).resolve()
...@@ -40,8 +40,8 @@ sys.path.append(str(_current_file.parent.parent)) ...@@ -40,8 +40,8 @@ sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, str_to_dtype from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None] quantization_list: list[Optional[str]] = [None]
if fp8_available: if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
...@@ -301,7 +301,7 @@ def _test_linear( ...@@ -301,7 +301,7 @@ def _test_linear(
# Implementation with fusible operation # Implementation with fusible operation
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
ops = [] ops = []
linear_op = None linear_op = None
bias_op = None bias_op = None
...@@ -351,7 +351,7 @@ def _test_linear( ...@@ -351,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test) bias_op.bias.copy_(b_test)
del w_test del w_test
del b_test del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test) y_test = model(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
......
...@@ -8,7 +8,7 @@ from pathlib import Path ...@@ -8,7 +8,7 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te
""" """
Distributed numerics tests Distributed numerics tests
...@@ -26,12 +26,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager ...@@ -26,12 +26,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.") pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
FP8GlobalStateManager.is_fp8_block_scaling_available() return_reason=True
) )
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count()) NUM_PROCS: int = min(4, torch.cuda.device_count())
......
...@@ -8,7 +8,7 @@ from pathlib import Path ...@@ -8,7 +8,7 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te
""" """
Distributed numerics tests Distributed numerics tests
...@@ -23,12 +23,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager ...@@ -23,12 +23,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.") pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
FP8GlobalStateManager.is_fp8_block_scaling_available() return_reason=True
) )
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count()) NUM_PROCS: int = min(4, torch.cuda.device_count())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment