Unverified Commit 85a91997 authored by Kirthi Shankar Sivamani, committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make Jax change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CG function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
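For context before the diff: the core rename this commit applies everywhere is that the FP8-specific entry points become recipe-agnostic, so the same context managers work for FP8, MXFP8, NVFP4, and block-scaling recipes. A minimal before/after sketch based on the calls visible in this diff (the recipe choice and layer sizes are illustrative):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Old, FP8-specific spellings removed throughout this diff:
#   with te.fp8_model_init(enabled=True, recipe=my_recipe): ...
#   with te.fp8_autocast(enabled=True, fp8_recipe=my_recipe): ...

# New, recipe-agnostic spellings (illustrative recipe choice):
my_recipe = recipe.DelayedScaling()

with te.quantized_model_init(enabled=True, recipe=my_recipe):
    model = te.Linear(1024, 1024)  # parameters are created already quantized

x = torch.randn(32, 1024, device="cuda")
with te.autocast(enabled=True, recipe=my_recipe):
    y = model(x)
y.sum().backward()
```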
......@@ -5,7 +5,6 @@
"""Unit tests for context parallel utils."""
import torch
import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp,
......
......@@ -14,20 +14,22 @@ import pytest
import torch
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
from transformer_engine.pytorch import (
make_graphed_callables,
autocast,
quantized_model_init,
TransformerLayer,
DotProductAttention,
InferenceParams,
is_bf16_available,
)
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
_current_file = pathlib.Path(__file__).resolve()
......@@ -42,7 +44,7 @@ from utils import (
reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
if is_bf16_available():
param_types.append(torch.bfloat16)
model_configs_infer = {
......@@ -238,7 +240,7 @@ def get_model(
if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
TransformerLayer(
hidden_size=hidden_size,
......@@ -261,7 +263,7 @@ def get_model(
for layer_number in range(1, num_layers + 1)
]
if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
DotProductAttention(
kv_channels=config.head_dim_qk,
......@@ -559,9 +561,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
model[i],
sample_args,
num_warmup_iters=10,
fp8_enabled=is_fp8,
enabled=is_fp8,
sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
)
for i in range(num_layers)
]
......@@ -654,7 +656,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=is_fp8, recipe=fp8_recipe):
for m in model:
incremental_output = m(
*incremental_output,
......
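The CUDA-graphs helper follows the same pattern: its `fp8_enabled`/`fp8_recipe` keywords become `enabled`/`recipe`, as the hunks above show. A hedged sketch of the new call shape (the layer configuration and input shape are placeholders):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
).cuda()
sample_args = (torch.randn(128, 2, 1024, device="cuda"),)  # (seq, batch, hidden)

graphed_layer = te.make_graphed_callables(
    layer,
    sample_args,
    num_warmup_iters=10,             # as in the test above
    enabled=True,                    # was fp8_enabled=True
    recipe=recipe.DelayedScaling(),  # was fp8_recipe=...
)
```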
......@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import is_fp8_available
from test_numerics import (
_emulate_linear,
......@@ -45,7 +45,7 @@ FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_available = is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
......@@ -117,7 +117,7 @@ class AllGather(torch.autograd.Function):
def _run_forward_backward(x, model, parallel_mode=None, group=None):
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x)
y.requires_grad_(True)
......@@ -413,13 +413,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
......@@ -532,7 +532,7 @@ def test_per_tensor_scaling(
LOSS_MULTIPLIER = 100
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x)
model.zero_grad()
if parallel_mode == "column":
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
......
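Tensor and quantizer classes are likewise re-exported at the package root, so code no longer reaches into `transformer_engine.pytorch.tensor.*` internals. A small sketch of quantizing a tensor through the now-public classes, assuming the `Float8Quantizer` constructor keeps its `(scale, amax, fp8_dtype)` signature (the scale and amax values are illustrative):

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer

# Per-tensor scaling quantizer built by hand (values illustrative).
quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = torch.randn(128, 128, device="cuda")
x_fp8 = quantizer(x)  # quantize; returns a Float8Tensor
assert isinstance(x_fp8, Float8Tensor)
x_hp = x_fp8.dequantize()  # back to high precision
```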
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib, os
import pathlib
from nvdlfw_inspect.config_manager import ConfigManager
......
......@@ -8,18 +8,22 @@ import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
LOG_QUANTIZED_CONFIG_BASE = """
......@@ -128,7 +132,7 @@ def test_sanity(feature_dirs):
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
with te.autocast(recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
......@@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
with te.autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
......
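The capability probes move from `FP8GlobalStateManager` static methods to top-level functions. By default they return a plain bool, and `return_reason=True` restores the `(available, reason)` tuple these tests use to build skip markers. A sketch of both forms:

```python
import pytest
import transformer_engine.pytorch as te

# Boolean form, for simple branching:
if not te.is_fp8_available():
    print("FP8 not supported on this device")

# (available, reason) form, as used throughout these tests:
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_requires_fp8():
    ...
```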
......@@ -17,19 +17,19 @@ import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.quantization import _default_sf_compute
from transformer_engine.pytorch import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
is_fp8_available,
)
from transformer_engine.pytorch.module.base import (
_2X_ACC_DGRAD,
_2X_ACC_FPROP,
_2X_ACC_WGRAD,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID
......@@ -250,7 +250,7 @@ def _init_model(weight):
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward()
debug_api.step()
......@@ -547,7 +547,7 @@ def run_per_tensor_scaling(
LOSS_MULTIPLIER = 100
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=True)
model.zero_grad()
y.retain_grad()
......
......@@ -7,11 +7,10 @@ import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import create_config_file
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
B, S, H, D = 64, 64, 64, 64
......@@ -68,7 +67,7 @@ def _get_model(model_key):
def _run_forward_backward(model, fp8):
for _ in range(3):
inp = torch.randn((S, B, H)).cuda()
with te.fp8_autocast(enabled=fp8):
with te.autocast(enabled=fp8):
out = model(inp)
out.sum().backward()
debug_api.step()
......
......@@ -21,13 +21,13 @@ from transformer_engine.common.recipe import (
Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
QuantizedTensor,
Float8Tensor,
Float8CurrentScalingQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor):
......@@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
......@@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
......@@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
with te.quantized_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
......@@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
with te.autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
recipe=quantization_recipe(quantization),
amax_reduction_group=mock_group,
):
y = model(x)
......
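`fp8_group` is renamed to `amax_reduction_group`, which names what the argument actually does: the process group over which amax statistics are reduced. A distributed sketch under the usual torchrun assumptions (one GPU per rank; group choice illustrative):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Assumes launch via torchrun with one GPU per rank.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = te.Linear(1024, 1024)
x = torch.randn(16, 1024, device="cuda")

with te.autocast(
    enabled=True,
    recipe=recipe.DelayedScaling(),
    amax_reduction_group=dist.group.WORLD,  # was fp8_group=...
):
    y = model(x)
```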
......@@ -110,9 +110,9 @@ def _train(args):
build_model_context = nullcontext
build_model_context_args = {}
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch import quantized_model_init
build_model_context = fp8_model_init
build_model_context = quantized_model_init
build_model_context_args["enabled"] = True
# Build the model with the specified context
......
......@@ -18,9 +18,12 @@ import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
Float8Tensor,
Float8Quantizer,
MXFP8Quantizer,
)
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
......@@ -171,12 +174,12 @@ def _parse_args(argv=None, namespace=None):
opts.p2p = True
if opts.atomic:
if not te.fp8.check_fp8_support():
if not te.is_fp8_available():
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8"
if opts.fp8_output:
assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute."
assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts
......
......@@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--quantization",
......@@ -438,7 +438,7 @@ def _train(opts):
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
)
with te.fp8_model_init(enabled=opts.fp8_init):
with te.quantized_model_init(enabled=opts.fp8_init):
test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0:
......@@ -450,7 +450,7 @@ def _train(opts):
ref_args, ref_kwargs, _ = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
)
with te.fp8_model_init(enabled=opts.fp8_init):
with te.quantized_model_init(enabled=opts.fp8_init):
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
......@@ -473,7 +473,9 @@ def _train(opts):
layer_contexts = [
(
partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
partial(
te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world
)
if opts.num_layers_at_start_in_bf16 <= i
and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
else nullcontext
......
......@@ -26,8 +26,7 @@ from transformer_engine.common.recipe import (
Recipe,
QParams,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors
......@@ -75,7 +74,7 @@ def quantization_recipe() -> Recipe:
return Float8BlockScaling()
if QUANTIZATION == "nvfp4":
return nvfp4_vanilla()
return te.fp8.get_default_fp8_recipe()
return te.quantization.get_default_fp8_recipe()
def main(argv=None, namespace=None):
......@@ -316,15 +315,15 @@ def _apply_models(
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node.requires_grad_()
input_distributed.requires_grad_()
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
recipe=quantization_recipe(),
):
output_single_node = model_single_node(input_single_node, **kwargs)
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
fp8_group=NCCL_WORLD,
recipe=quantization_recipe(),
amax_reduction_group=NCCL_WORLD,
):
output_distributed = model_distributed(input_distributed, **kwargs)
return output_single_node, output_distributed
......
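The internal `te.fp8` module is renamed to `te.quantization`, so helpers like `get_default_fp8_recipe` move with it. A hedged sketch of the recipe-selection pattern used above (the helper name mirrors the test, and the `kind` strings are hypothetical):

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

def quantization_recipe(kind):
    # Hypothetical selector mirroring the test above.
    if kind == "fp8_block_scaling":
        return Float8BlockScaling()
    # The old te.fp8.get_default_fp8_recipe() now lives under te.quantization.
    return te.quantization.get_default_fp8_recipe()
```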
......@@ -9,21 +9,18 @@ import datetime
import os
import sys
from functools import wraps
import math
import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
NVFP4BlockScaling,
Format,
Recipe,
QParams,
CustomRecipe,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
......@@ -506,7 +503,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
)
# run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
with te.autocast(enabled=True, recipe=recipe):
y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
x,
w,
......@@ -524,7 +521,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# run the reference
reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe):
with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
x,
w,
......@@ -700,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
)
# run the recipe under test
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
with te.autocast(enabled=True, recipe=recipe):
y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
w,
......@@ -717,7 +714,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# run the reference
reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe):
with te.autocast(enabled=True, recipe=reference_recipe):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
......
......@@ -8,15 +8,15 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available
if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
TEST_ROOT = Path(__file__).parent.resolve()
......
......@@ -9,13 +9,12 @@ import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
......
......@@ -20,16 +20,15 @@ import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
QuantizedTensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
......@@ -39,9 +38,9 @@ from utils import dtype_tols, make_recipe, quantization_tols
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
......@@ -427,7 +426,7 @@ def _test_basic_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
......@@ -440,7 +439,7 @@ def _test_basic_linear(
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -593,7 +592,7 @@ def _test_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -612,7 +611,7 @@ def _test_linear(
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -759,7 +758,7 @@ def _test_mlp(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.GELU(),
te_ops.Linear(
......@@ -795,7 +794,7 @@ def _test_mlp(
# Warmup steps
for _ in range(3):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
x_test.grad = None
......@@ -806,7 +805,7 @@ def _test_mlp(
model[3].bias.grad = None
# Forward and backward step
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -944,7 +943,7 @@ def _test_fp8_scale_update(
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -1004,7 +1003,7 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_linear with {config=}")
quantization, tensor_parallel_mode, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_linear(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
......@@ -1018,7 +1017,7 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_mlp with {config=}")
quantization, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
dtype = torch.bfloat16 if is_bf16_available() else torch.float32
_test_mlp(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
......
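The internal `is_bf16_compatible` utility is now exported publicly as `is_bf16_available`; usage is otherwise unchanged. A one-line sketch of the dtype-selection idiom from the tests above:

```python
import torch
from transformer_engine.pytorch import is_bf16_available

# Prefer bfloat16 when the device supports it, as the parallel tests above do.
dtype = torch.bfloat16 if is_bf16_available() else torch.float32
```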
......@@ -16,23 +16,23 @@ import sys
import pytest
import torch
from typing import Optional, Iterable
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
QuantizedTensor,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
......@@ -40,8 +40,8 @@ sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
......@@ -301,7 +301,7 @@ def _test_linear(
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
ops = []
linear_op = None
bias_op = None
......@@ -351,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......
......@@ -8,7 +8,7 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
"""
Distributed numerics tests
......@@ -26,12 +26,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......
......@@ -8,7 +8,7 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
"""
Distributed numerics tests
......@@ -23,12 +23,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......