Unverified commit 85a91997 authored by Kirthi Shankar Sivamani, committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make JAX change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CUDA graph function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix NVFP4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
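Taken together, the hunks below apply one mechanical rename across the test suite: `fp8_autocast` becomes `autocast`, `fp8_model_init` becomes `quantized_model_init`, and the `FP8GlobalStateManager` availability checks become top-level helpers. A rough before/after sketch using only names that appear in the diff (the layer sizes and the DelayedScaling recipe are arbitrary example choices, not the tests' exact code):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Old: fp8_available, reason = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason = te.is_fp8_available(return_reason=True)

# Old: with te.fp8_model_init(enabled=..., recipe=...):
with te.quantized_model_init(enabled=fp8_available, recipe=DelayedScaling()):
    model = te.Linear(128, 128)

x = torch.randn(32, 128, device="cuda", requires_grad=True)

# Old: with te.fp8_autocast(enabled=..., fp8_recipe=...):
with te.autocast(enabled=fp8_available, recipe=DelayedScaling()):
    y = model(x)
y.sum().backward()
```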
......@@ -7,8 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......
......@@ -6,13 +6,12 @@ import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
import torch
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
NUM_PROCS: int = torch.cuda.device_count()
......@@ -34,7 +33,7 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims):
......
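The same pattern recurs in most files below: the `FP8GlobalStateManager` static methods give way to top-level helpers that return a bare boolean by default and an (available, reason) pair when `return_reason=True`. A short illustrative sketch:

```python
import transformer_engine.pytorch as te

# Boolean-only form, used where the skip reason is not needed.
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()

# Tuple form, matching the old FP8GlobalStateManager behaviour.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# The PyTorch version check is likewise exposed at the top level.
pytorch_recent_enough = te.torch_version() >= (2, 4, 0)
```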
......@@ -4,16 +4,15 @@
import pytest
import torch
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def check_nvfp4_gemm_versus_reference(
......
......@@ -4,15 +4,13 @@
import pytest
import torch
import transformer_engine as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.distributed import fp8_autocast
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
class GetRecipes:
......@@ -152,16 +150,16 @@ def check_nvfp4_module_versus_reference(
# Create native module
print("\nCreate native module")
if module_class == te.pytorch.Linear:
native_module = te.pytorch.Linear(
if module_class == te.Linear:
native_module = te.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.pytorch.LayerNormLinear:
native_module = te.pytorch.LayerNormLinear(
elif module_class == te.LayerNormLinear:
native_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
......@@ -176,16 +174,16 @@ def check_nvfp4_module_versus_reference(
# Create reference module
print("Create reference module")
if module_class == te.pytorch.Linear:
ref_module = te.pytorch.Linear(
if module_class == te.Linear:
ref_module = te.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
params_dtype=x_dtype,
)
elif module_class == te.pytorch.LayerNormLinear:
ref_module = te.pytorch.LayerNormLinear(
elif module_class == te.LayerNormLinear:
ref_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
......@@ -232,13 +230,13 @@ def check_nvfp4_module_versus_reference(
grad_output = grad_output_val.clone().detach()
# Native forward/backward
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
with te.autocast(enabled=True, recipe=nvfp4_recipe):
# enable weight cache by giving is_first_microbatch
y_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
y_ref = ref_module(x_ref)
y_ref.backward(grad_output)
......@@ -361,7 +359,7 @@ def test_nvfp4_linear_versus_reference(
pytest.skip("RHT is only supported for bfloat16 input")
check_nvfp4_module_versus_reference(
module_class=te.pytorch.Linear,
module_class=te.Linear,
in_features=in_features,
out_features=out_features,
bias=bias,
......@@ -394,7 +392,7 @@ def check_nvfp4_layernorm_linear_versus_reference(
reset_rng_states()
# Native module
native_module = te.pytorch.LayerNormLinear(
native_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
......@@ -406,7 +404,7 @@ def check_nvfp4_layernorm_linear_versus_reference(
# Reference module
reset_rng_states()
ref_module = te.pytorch.LayerNormLinear(
ref_module = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
......@@ -456,12 +454,12 @@ def check_nvfp4_layernorm_linear_versus_reference(
grad_output = grad_output_val.clone().detach()
# Native forward/backward
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
with te.autocast(enabled=True, recipe=nvfp4_recipe):
y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
y_ref, ln_out_ref = ref_module(x_ref)
y_ref.backward(grad_output)
......
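The NVFP4 module tests above drive the new context manager exactly like the FP8 ones; only the recipe object differs. A minimal sketch, assuming a default-constructed NVFP4BlockScaling recipe and bfloat16 inputs (the tests skip RHT for anything but bfloat16):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

nvfp4_available, _ = te.is_nvfp4_available(return_reason=True)
if nvfp4_available:
    module = te.Linear(64, 64, params_dtype=torch.bfloat16)
    x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    # Old: with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
    with te.autocast(enabled=True, recipe=NVFP4BlockScaling()):
        y = module(x)
    y.backward(torch.ones_like(y))
```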
......@@ -4,20 +4,16 @@
import pytest
import torch
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
......
......@@ -9,25 +9,18 @@
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
from typing import Tuple
import math
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
import pytest
import torch
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
......
......@@ -4,10 +4,10 @@
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
seed = 12345
torch.manual_seed(seed)
......
......@@ -12,13 +12,15 @@ import pathlib
import pytest
import torch
from typing import Optional
import transformer_engine.pytorch as te
from utils import make_recipe
# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
# Test cases for loading checkpoint files
......@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
if name == "ops_linear":
return te.ops.Linear(1, 1)
if name == "linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.Linear(16, 16)
if name == "ops_linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
with te.quantized_model_init(recipe=make_recipe("fp8")):
return te.ops.Linear(16, 16)
if name == "linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.Linear(32, 32)
if name == "ops_linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
with te.quantized_model_init(recipe=make_recipe("mxfp8")):
return te.ops.Linear(32, 32)
raise ValueError(f"Unrecognized module name ({name})")
......
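For context, the renamed init context used in the checkpoint tests above; the tests build their recipes through a local make_recipe helper, so this sketch substitutes an explicit DelayedScaling recipe:

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Old: with te.fp8_model_init(recipe=...):
# New: same context, renamed, accepting any supported recipe object.
with te.quantized_model_init(enabled=True, recipe=DelayedScaling()):
    linear = te.Linear(16, 16)  # parameters created directly in quantized storage
```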
......@@ -12,14 +12,13 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
......@@ -79,9 +78,9 @@ def _warmup_model(
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.fp8_autocast(
with te.autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
......@@ -159,8 +158,8 @@ def _measure_cached_memory(
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.fp8_autocast(
enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
......
......@@ -13,20 +13,23 @@ from transformer_engine.pytorch import (
Linear,
MultiheadAttention,
TransformerLayer,
fp8_autocast,
fp8_model_init,
autocast,
quantized_model_init,
make_graphed_callables,
is_fp8_available,
is_fp8_block_scaling_available,
is_mxfp8_available,
is_bf16_available,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
# Check if FP8 is supported.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
......@@ -93,7 +96,7 @@ if fp8_available:
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
......@@ -201,7 +204,7 @@ def _test_cuda_graphs(
fp8_weight_caching = False
# Create modules.
with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
if module == "transformer":
modules = [
TransformerLayer(
......@@ -281,9 +284,9 @@ def _test_cuda_graphs(
model,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
enabled=fp8,
cache_quantized_params=fp8_weight_caching,
recipe=fp8_recipe,
)
elif graph_mode == "individual":
# Graph individual modules.
......@@ -292,9 +295,9 @@ def _test_cuda_graphs(
module,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
enabled=fp8,
cache_quantized_params=fp8_weight_caching,
recipe=fp8_recipe,
)
for module in modules
]
......@@ -311,7 +314,7 @@ def _test_cuda_graphs(
for grad_accumulation_step in range(2):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8, recipe=fp8_recipe):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
......@@ -455,7 +458,7 @@ def _test_cuda_graphs_with_dot_product_attention(
model,
generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
num_warmup_iters=10,
fp8_enabled=False,
enabled=False,
)
# Forward and backward passes.
......
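The CUDA-graph helper picks up the same generalization: the hunks above rename its quantization keywords. A hedged sketch of the new call (layer size, input shape, and warmup count are arbitrary):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

module = te.Linear(128, 128)
sample_args = (torch.randn(8, 128, device="cuda", requires_grad=True),)

graphed_module = te.make_graphed_callables(
    module,
    sample_args,
    num_warmup_iters=10,
    enabled=True,                   # was: fp8_enabled
    cache_quantized_params=False,   # was: fp8_weight_caching
    recipe=DelayedScaling(),        # was: fp8_recipe
)
```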
......@@ -5,23 +5,23 @@
import pytest
import torch
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast
from transformer_engine.pytorch import Linear
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
autocast,
Linear,
LayerNormLinear,
LayerNormMLP,
GroupedLinear,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.module.grouped_linear import GroupedLinear
import transformer_engine.pytorch.ops as te_ops
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
def test_custom_recipe_sanity(module_type):
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......@@ -57,7 +57,7 @@ def test_custom_recipe_sanity(module_type):
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
# Execute with custom recipe
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
......@@ -67,7 +67,7 @@ def test_custom_recipe_sanity(module_type):
def test_custom_recipe_grouped_linear_sanity():
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......@@ -93,7 +93,7 @@ def test_custom_recipe_grouped_linear_sanity():
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp, m_splits)
loss = out.float().sum()
loss.backward()
......@@ -102,7 +102,7 @@ def test_custom_recipe_grouped_linear_sanity():
def test_custom_recipe_matches_current_scaling():
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......@@ -124,7 +124,7 @@ def test_custom_recipe_matches_current_scaling():
# Reference: use Float8CurrentScaling recipe
ref_recipe = recipe.Float8CurrentScaling()
with fp8_autocast(enabled=True, fp8_recipe=ref_recipe):
with autocast(enabled=True, recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -155,7 +155,7 @@ def test_custom_recipe_matches_current_scaling():
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
with autocast(enabled=True, recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -189,7 +189,7 @@ def test_custom_recipe_matches_current_scaling():
def test_custom_recipe_ops_linear_2_1_layout():
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......@@ -212,7 +212,7 @@ def test_custom_recipe_ops_linear_2_1_layout():
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom):
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
......@@ -221,7 +221,7 @@ def test_custom_recipe_ops_linear_2_1_layout():
def test_custom_recipe_factory_invocation_counts_and_cycling():
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......@@ -256,7 +256,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():
# Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
# and backward to build 2 quantizers (cycled from 1 factory).
with fp8_autocast(enabled=True, fp8_recipe=custom):
with autocast(enabled=True, recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
......@@ -270,7 +270,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():
def test_factories_return_distinct_instances_and_buffers():
available, reason = check_fp8_support()
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
......
......@@ -4,7 +4,6 @@
import pytest
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
......
......@@ -4,22 +4,20 @@
import pytest
import torch
import transformer_engine as te
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
supported = te.is_fp8_block_scaling_available()
emulated = get_device_compute_capability() >= (10, 0)
return supported and not emulated
......
......@@ -8,15 +8,12 @@ import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
......@@ -32,7 +29,7 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
recipe_emulated = get_device_compute_capability() >= (10, 0)
......
......@@ -8,11 +8,9 @@ import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
......@@ -23,7 +21,7 @@ if tensor_dump_dir_env is not None:
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class GetRecipes:
......@@ -394,7 +392,7 @@ class TestFP8RecipeLinearBase:
# recipe1
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
with autocast(enabled=True, recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
else:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
......@@ -402,7 +400,7 @@ class TestFP8RecipeLinearBase:
# recipe2
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
with autocast(enabled=True, recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
else:
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
......@@ -617,7 +615,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
# recipe1
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
with autocast(enabled=True, recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
......@@ -639,7 +637,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
# recipe2
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
with autocast(enabled=True, recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
......
......@@ -11,12 +11,11 @@ import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
from transformer_engine.pytorch import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
get_device_compute_capability,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex
# PyTorch tensor dtypes
......
......@@ -11,13 +11,11 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch import (
Float8Quantizer,
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
......@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
# delayed scaling
......
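A recurring piece of the cleanup in these hunks: quantizer and quantized-tensor classes, plus a couple of utilities, are now imported from transformer_engine.pytorch directly rather than from internal submodules. Roughly (this list only collects names that appear in the diff; it is not the full public surface):

```python
# Old (scattered across internal submodules):
#   from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
#   from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
#   from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
#   from transformer_engine.pytorch.utils import get_device_compute_capability
# New (top-level exports):
from transformer_engine.pytorch import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
    MXFP8Quantizer,
    NVFP4Quantizer,
    QuantizedTensor,
    get_device_compute_capability,
)
```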
......@@ -11,14 +11,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
from transformer_engine.pytorch.utils import gpu_autocast_ctx
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
class TestFusedOptimizer:
......@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
build_model_context = fp8_model_init
build_model_context = quantized_model_init
build_model_context_args["enabled"] = True
with build_model_context(**build_model_context_args):
......@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_fp32_master_store_param_remainders(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
store_param_remainders=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_bf16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
......@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=1e-2,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_bf16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
......@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
......@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
skip_assert=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
model = MultiheadAttention(
......@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_weight_cast(self):
dtype = torch.bfloat16
with fp8_model_init(enabled=True, recipe=DelayedScaling()):
with quantized_model_init(enabled=True, recipe=DelayedScaling()):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
......
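One of the commit bullets above makes the bf16 check public; these hunks swap the internal utils helper for it. Sketch of the new usage:

```python
import torch
from transformer_engine.pytorch import is_bf16_available

# Old: from transformer_engine.pytorch.utils import is_bf16_compatible
dtypes = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)
```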
......@@ -7,8 +7,6 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional
import pytest
......@@ -17,7 +15,6 @@ import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
......@@ -28,28 +25,27 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
from transformer_engine.pytorch import (
QuantizedTensor,
Float8CurrentScalingQuantizer,
Float8Quantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
# Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
if is_bf16_available(): # bf16 requires sm_80 or higher
_dtypes.append(torch.bfloat16)
# Supported devices
......@@ -372,7 +368,7 @@ class TestFuser:
)
# Construct model
with te.fp8_model_init(recipe=recipe):
with te.quantized_model_init(recipe=recipe):
model = te_ops.basic.BasicLinear(
size,
size,
......@@ -404,7 +400,7 @@ class TestFuser:
)
# Training step
with te.fp8_autocast(fp8_recipe=recipe):
with te.autocast(recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
......@@ -473,7 +469,7 @@ class TestFuser:
)
# Construct operation
with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
with torch.no_grad():
op.weight.copy_(w_test)
......@@ -530,7 +526,7 @@ class TestFuser:
# Construct operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
# Check forward and backward pass
......@@ -540,7 +536,7 @@ class TestFuser:
device=device,
requires_grad=True,
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
y = op(x)
y.backward(torch.zeros_like(y))
......@@ -553,7 +549,7 @@ class TestFuser:
x.grad = None
op.weight.grad = None
with torch.autocast(device_type=device.type, dtype=autocast_dtype):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = op(x)
y.backward(torch.zeros_like(y))
assert y.dtype == autocast_dtype
......@@ -803,7 +799,7 @@ class TestBasicOps:
# Implementation with fusible operation
op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
recipe = make_recipe(quantization)
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
......@@ -897,7 +893,7 @@ class TestBasicOps:
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
......@@ -914,7 +910,7 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -1075,7 +1071,7 @@ class TestBasicOps:
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.Linear(
in_features,
out_features,
......@@ -1091,7 +1087,7 @@ class TestBasicOps:
del b_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
......@@ -1192,7 +1188,7 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -1354,7 +1350,7 @@ class TestBasicOps:
op,
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -1654,7 +1650,7 @@ class TestBasicOps:
make_op(cache_quantized_input=cache_quantized_input),
te_ops.Quantize(forward=quantized_compute, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -1721,7 +1717,7 @@ class TestBasicOps:
te_ops.SwiGLU(),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -1792,7 +1788,7 @@ class TestBasicOps:
te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -2002,7 +1998,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2018,7 +2014,7 @@ class TestFusedOps:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -2112,7 +2108,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2129,7 +2125,7 @@ class TestFusedOps:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
......@@ -2218,7 +2214,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2234,7 +2230,7 @@ class TestFusedOps:
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
......@@ -2325,7 +2321,7 @@ class TestFusedOps:
with torch.no_grad():
model[1].bias.copy_(b_test)
del b_test
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -2503,7 +2499,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
with te.quantized_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.MakeExtraOutput(in_place=True),
te_ops.Linear(
......@@ -2517,7 +2513,7 @@ class TestFusedOps:
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
......@@ -2598,7 +2594,7 @@ class TestFusedOps:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
with te.quantized_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
......@@ -2612,7 +2608,7 @@ class TestFusedOps:
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = model(x_test)
(y_test * dy_test).sum().backward()
......@@ -2672,7 +2668,7 @@ class TestCheckpointing:
# Construct model
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_save = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
......@@ -2683,7 +2679,7 @@ class TestCheckpointing:
x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
dy = torch.randn(out_shape, dtype=dtype, device=device)
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(x)
y.backward(dy)
optim_save.step()
......@@ -2712,14 +2708,14 @@ class TestCheckpointing:
ys_save = []
for i in range(post_checkpoint_steps):
optim_save.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_save(xs_save[i])
y.backward(dys[i])
optim_save.step()
ys_save.append(y)
# Load checkpoint
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
model_load = te_ops.Sequential(
te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
)
......@@ -2732,7 +2728,7 @@ class TestCheckpointing:
ys_load = []
for i in range(post_checkpoint_steps):
optim_load.zero_grad()
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y = model_load(xs_load[i])
y.backward(dys[i])
optim_load.step()
......@@ -2819,7 +2815,7 @@ class TestSequentialModules:
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
if normalization == "LayerNorm":
norm = te_ops.LayerNorm(
hidden_size,
......@@ -2850,6 +2846,6 @@ class TestSequentialModules:
dtype=dtype,
)
forward = te_ops.Sequential(norm, ffn1, act, ffn2)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch import TransformerLayer
class SimpleTEModel(PreTrainedModel):
......