Unverified Commit 85a91997 authored by Kirthi Shankar Sivamani, committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make JAX change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CUDA graph function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
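
In short, the PR renames the FP8-specific entry points to quantization-generic ones: `fp8_autocast` becomes `autocast` (with the `fp8_recipe` kwarg renamed to `recipe`), `fp8_model_init` becomes `quantized_model_init`, and the `FP8GlobalStateManager.is_*_available()` checks are exposed as top-level `te.is_*_available()` helpers. A minimal sketch of the new usage, pieced together from the test changes below; the module shape, input, and `DelayedScaling` recipe are illustrative, and a CUDA device with FP8 support is assumed:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Availability checks are now top-level; with return_reason=True they
# return an (available, reason) tuple instead of going through
# FP8GlobalStateManager.
fp8_available, reason = te.is_fp8_available(return_reason=True)

fp8_recipe = DelayedScaling()

# quantized_model_init replaces fp8_model_init for building modules
# with quantized parameters.
with te.quantized_model_init(enabled=fp8_available, recipe=fp8_recipe):
    model = te.Linear(16, 16)

# autocast replaces fp8_autocast; the fp8_recipe kwarg is now recipe.
x = torch.randn(16, 16, device="cuda")
with te.autocast(enabled=fp8_available, recipe=fp8_recipe):
    y = model(x)
y.float().sum().backward()
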
@@ -7,8 +7,7 @@ import sys
 import pytest
 import torch
 import transformer_engine
-from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
-from transformer_engine.pytorch import TransformerLayer, Linear
+from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
...
@@ -6,13 +6,12 @@ import os
 import pytest
 import subprocess
 from pathlib import Path
-from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import transformer_engine.pytorch as te
 import torch
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 NUM_PROCS: int = torch.cuda.device_count()
@@ -34,7 +33,7 @@ def _run_test(fp_init, sharding_dims):
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
 def test_distributed(fp8_init, sharding_dims):
...
@@ -4,16 +4,15 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+from transformer_engine.pytorch import NVFP4Quantizer
 from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
 from transformer_engine.pytorch.experimental import utils
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
 def check_nvfp4_gemm_versus_reference(
...
@@ -4,15 +4,13 @@
 import pytest
 import torch
-import transformer_engine as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.distributed import fp8_autocast
+import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 from transformer_engine.pytorch.experimental import quantization_nvfp4
 from transformer_engine.pytorch.experimental import utils
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
 class GetRecipes:
@@ -152,16 +150,16 @@ def check_nvfp4_module_versus_reference(
     # Create native module
     print("\nCreate native module")
-    if module_class == te.pytorch.Linear:
-        native_module = te.pytorch.Linear(
+    if module_class == te.Linear:
+        native_module = te.Linear(
             in_features=in_features,
             out_features=out_features,
             bias=bias,
             device=device,
             params_dtype=x_dtype,
         )
-    elif module_class == te.pytorch.LayerNormLinear:
-        native_module = te.pytorch.LayerNormLinear(
+    elif module_class == te.LayerNormLinear:
+        native_module = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
@@ -176,16 +174,16 @@ def check_nvfp4_module_versus_reference(
     # Create reference module
     print("Create reference module")
-    if module_class == te.pytorch.Linear:
-        ref_module = te.pytorch.Linear(
+    if module_class == te.Linear:
+        ref_module = te.Linear(
             in_features=in_features,
             out_features=out_features,
             bias=bias,
             device=device,
             params_dtype=x_dtype,
         )
-    elif module_class == te.pytorch.LayerNormLinear:
-        ref_module = te.pytorch.LayerNormLinear(
+    elif module_class == te.LayerNormLinear:
+        ref_module = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
@@ -232,13 +230,13 @@ def check_nvfp4_module_versus_reference(
         grad_output = grad_output_val.clone().detach()
         # Native forward/backward
-        with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
+        with te.autocast(enabled=True, recipe=nvfp4_recipe):
             # enable weight cache by giving is_first_microbatch
             y_native = native_module(x_native, is_first_microbatch=(step == 0))
         y_native.backward(grad_output)
         # Reference forward/backward
-        with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
+        with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
             y_ref = ref_module(x_ref)
         y_ref.backward(grad_output)
@@ -361,7 +359,7 @@ def test_nvfp4_linear_versus_reference(
         pytest.skip("RHT is only supported for bfloat16 input")
     check_nvfp4_module_versus_reference(
-        module_class=te.pytorch.Linear,
+        module_class=te.Linear,
         in_features=in_features,
         out_features=out_features,
         bias=bias,
@@ -394,7 +392,7 @@ def check_nvfp4_layernorm_linear_versus_reference(
     reset_rng_states()
     # Native module
-    native_module = te.pytorch.LayerNormLinear(
+    native_module = te.LayerNormLinear(
         in_features=in_features,
         out_features=out_features,
         bias=bias,
@@ -406,7 +404,7 @@ def check_nvfp4_layernorm_linear_versus_reference(
     # Reference module
     reset_rng_states()
-    ref_module = te.pytorch.LayerNormLinear(
+    ref_module = te.LayerNormLinear(
         in_features=in_features,
         out_features=out_features,
         bias=bias,
@@ -456,12 +454,12 @@ def check_nvfp4_layernorm_linear_versus_reference(
         grad_output = grad_output_val.clone().detach()
         # Native forward/backward
-        with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
+        with te.autocast(enabled=True, recipe=nvfp4_recipe):
             y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
         y_native.backward(grad_output)
         # Reference forward/backward
-        with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
+        with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
             y_ref, ln_out_ref = ref_module(x_ref)
         y_ref.backward(grad_output)
...
@@ -4,20 +4,16 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import NVFP4Quantizer
+from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
 from transformer_engine.common.recipe import NVFP4BlockScaling
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.nvfp4_tensor import (
-    NVFP4Quantizer,
-)
-from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
 from transformer_engine.pytorch.experimental import utils
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
 def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
...
@@ -9,25 +9,18 @@
 # Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
 # together with the quantization functionality.
-from typing import Tuple
-import math
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import NVFP4Quantizer
 from transformer_engine.common.recipe import NVFP4BlockScaling
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.nvfp4_tensor import (
-    NVFP4Quantizer,
-)
 from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
 from transformer_engine.pytorch.experimental import utils
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
 import pytest
 import torch
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
 def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
...
@@ -4,10 +4,10 @@
 import pytest
 import torch
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import NVFP4Quantizer
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
 seed = 12345
 torch.manual_seed(seed)
...
@@ -12,13 +12,15 @@ import pathlib
 import pytest
 import torch
+from typing import Optional
 import transformer_engine.pytorch as te
 from utils import make_recipe
 # Check supported quantization schemes
-fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 # Test cases for loading checkpoint files
@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
         if name == "ops_linear":
             return te.ops.Linear(1, 1)
         if name == "linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.Linear(16, 16)
         if name == "ops_linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.ops.Linear(16, 16)
         if name == "linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.Linear(32, 32)
         if name == "ops_linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.ops.Linear(32, 32)
         raise ValueError(f"Unrecognized module name ({name})")
...
@@ -12,14 +12,13 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 from utils import ModelConfig, get_available_attention_backends
 # Check supported quantization schemes
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = te.is_fp8_available()
+mxfp8_available = te.is_mxfp8_available()
 quantization_recipes: Optional[recipe.Recipe] = [None]
 if fp8_available:
@@ -79,9 +78,9 @@ def _warmup_model(
     """Perform forward and backward pass"""
     tensor = _make_input()
     for module in modules:
-        with te.fp8_autocast(
+        with te.autocast(
             enabled=quantization_recipe is not None,
-            fp8_recipe=quantization_recipe,
+            recipe=quantization_recipe,
         ):
             tensor = module(tensor)
     tensor.sum().backward()
@@ -159,8 +158,8 @@ def _measure_cached_memory(
     tensor = inp
     memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
     for module in modules:
-        with te.fp8_autocast(
-            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
+        with te.autocast(
+            enabled=quantization_recipe is not None, recipe=quantization_recipe
         ), offload_context:
             tensor = module(tensor)
         tensor = sync_function(tensor)
...
@@ -13,20 +13,23 @@ from transformer_engine.pytorch import (
     Linear,
     MultiheadAttention,
     TransformerLayer,
-    fp8_autocast,
-    fp8_model_init,
+    autocast,
+    quantized_model_init,
     make_graphed_callables,
+    is_fp8_available,
+    is_fp8_block_scaling_available,
+    is_mxfp8_available,
+    is_bf16_available,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
 from utils import ModelConfig, reset_rng_states
 # Check if FP8 is supported.
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = is_fp8_available()
+fp8_block_scaling_available = is_fp8_block_scaling_available()
+mxfp8_available = is_mxfp8_available()
 # Reset RNG states.
 reset_rng_states()
@@ -93,7 +96,7 @@ if fp8_available:
 # Supported data types
 dtypes: List[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)
@@ -201,7 +204,7 @@ def _test_cuda_graphs(
         fp8_weight_caching = False
     # Create modules.
-    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
         if module == "transformer":
             modules = [
                 TransformerLayer(
@@ -281,9 +284,9 @@ def _test_cuda_graphs(
             model,
             (generate_data(model_config, dtype, warmup=True),),
             num_warmup_iters=10,
-            fp8_enabled=fp8,
-            fp8_weight_caching=fp8_weight_caching,
-            fp8_recipe=fp8_recipe,
+            enabled=fp8,
+            cache_quantized_params=fp8_weight_caching,
+            recipe=fp8_recipe,
         )
     elif graph_mode == "individual":
         # Graph individual modules.
@@ -292,9 +295,9 @@ def _test_cuda_graphs(
                 module,
                 (generate_data(model_config, dtype, warmup=True),),
                 num_warmup_iters=10,
-                fp8_enabled=fp8,
-                fp8_weight_caching=fp8_weight_caching,
-                fp8_recipe=fp8_recipe,
+                enabled=fp8,
+                cache_quantized_params=fp8_weight_caching,
+                recipe=fp8_recipe,
             )
             for module in modules
         ]
@@ -311,7 +314,7 @@ def _test_cuda_graphs(
     for grad_accumulation_step in range(2):
         input_ = generate_data(model_config, dtype)
         grad_output = generate_data(model_config, dtype, requires_grad=False)
-        with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
+        with autocast(enabled=fp8, recipe=fp8_recipe):
             kwargs = {}
             if fp8_weight_caching:
                 kwargs["is_first_microbatch"] = grad_accumulation_step == 0
@@ -455,7 +458,7 @@ def _test_cuda_graphs_with_dot_product_attention(
         model,
         generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
         num_warmup_iters=10,
-        fp8_enabled=False,
+        enabled=False,
     )
     # Forward and backward passes.
...
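
The CUDA-graphs entry point follows the same renaming: `make_graphed_callables` drops its `fp8_` prefixes (`fp8_enabled` becomes `enabled`, `fp8_weight_caching` becomes `cache_quantized_params`, `fp8_recipe` becomes `recipe`). A hedged sketch mirroring the calls in the test above; the module, input shape, and recipe are placeholders, and an FP8-capable GPU is assumed:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

module = te.Linear(16, 16)
sample_input = torch.randn(16, 16, device="cuda")

# Capture the module into a CUDA graph with the new keyword names.
graphed_module = te.make_graphed_callables(
    module,
    (sample_input,),
    num_warmup_iters=10,
    enabled=True,                  # was fp8_enabled
    cache_quantized_params=False,  # was fp8_weight_caching
    recipe=DelayedScaling(),       # was fp8_recipe
)
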
@@ -5,23 +5,23 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast
-from transformer_engine.pytorch import Linear
-import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
-from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
+    autocast,
+    Linear,
+    LayerNormLinear,
+    LayerNormMLP,
+    GroupedLinear,
     Float8CurrentScalingQuantizer,
 )
-from transformer_engine.pytorch.module.grouped_linear import GroupedLinear
+import transformer_engine.pytorch.ops as te_ops
 @pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
 def test_custom_recipe_sanity(module_type):
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
@@ -57,7 +57,7 @@ def test_custom_recipe_sanity(module_type):
     custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
     # Execute with custom recipe
-    with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
+    with autocast(enabled=True, recipe=custom_recipe):
         out = model(inp)
     loss = out.float().sum()
     loss.backward()
@@ -67,7 +67,7 @@ def test_custom_recipe_sanity(module_type):
 def test_custom_recipe_grouped_linear_sanity():
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
@@ -93,7 +93,7 @@ def test_custom_recipe_grouped_linear_sanity():
     custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
-    with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
+    with autocast(enabled=True, recipe=custom_recipe):
         out = model(inp, m_splits)
     loss = out.float().sum()
     loss.backward()
@@ -102,7 +102,7 @@ def test_custom_recipe_grouped_linear_sanity():
 def test_custom_recipe_matches_current_scaling():
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
@@ -124,7 +124,7 @@ def test_custom_recipe_matches_current_scaling():
     # Reference: use Float8CurrentScaling recipe
     ref_recipe = recipe.Float8CurrentScaling()
-    with fp8_autocast(enabled=True, fp8_recipe=ref_recipe):
+    with autocast(enabled=True, recipe=ref_recipe):
         out_ref = model_ref(inp_ref)
     # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
     ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
@@ -155,7 +155,7 @@ def test_custom_recipe_matches_current_scaling():
     custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
-    with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
+    with autocast(enabled=True, recipe=custom_recipe):
         out_custom = model_custom(inp_custom)
     # Assert dtypes for custom quantizers match reference mapping
     cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
@@ -189,7 +189,7 @@ def test_custom_recipe_matches_current_scaling():
 def test_custom_recipe_ops_linear_2_1_layout():
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
@@ -212,7 +212,7 @@ def test_custom_recipe_ops_linear_2_1_layout():
     custom = recipe.CustomRecipe(qfactory=quantizer_factory)
-    with fp8_autocast(enabled=True, fp8_recipe=custom):
+    with autocast(enabled=True, recipe=custom):
         out = op(inp)
     loss = out.float().sum()
     loss.backward()
@@ -221,7 +221,7 @@ def test_custom_recipe_ops_linear_2_1_layout():
 def test_custom_recipe_factory_invocation_counts_and_cycling():
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
@@ -256,7 +256,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():
     # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
     # and backward to build 2 quantizers (cycled from 1 factory).
-    with fp8_autocast(enabled=True, fp8_recipe=custom):
+    with autocast(enabled=True, recipe=custom):
         out = op(inp)
     loss = out.float().sum()
     loss.backward()
@@ -270,7 +270,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():
 def test_factories_return_distinct_instances_and_buffers():
-    available, reason = check_fp8_support()
+    available, reason = te.is_fp8_available(return_reason=True)
     if not torch.cuda.is_available() or not available:
         pytest.skip(f"FP8 unsupported on this device: {reason}")
...
@@ -4,7 +4,6 @@
 import pytest
 import torch
-import torch.distributed as dist
 import transformer_engine.pytorch as te
...
@@ -4,22 +4,20 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.utils import get_device_compute_capability
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
-    Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
 from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
 def fp8_blockwise_gemm_supported() -> bool:
-    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+    supported = te.is_fp8_block_scaling_available()
     emulated = get_device_compute_capability() >= (10, 0)
     return supported and not emulated
...
@@ -8,15 +8,12 @@ import os
 import pathlib
 import pytest
 import torch
-import transformer_engine as te
-import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import Float8BlockScaling
-from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
-    Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
 from references.blockwise_quantizer_reference import (
     BlockwiseQuantizerReference,
@@ -32,7 +29,7 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
 tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
 if tensor_dump_dir_env is not None:
     TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
+recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
 recipe_emulated = get_device_compute_capability() >= (10, 0)
...
@@ -8,11 +8,9 @@ import torch
 import pytest
 import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.common.recipe import Float8CurrentScaling
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
 # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
@@ -23,7 +21,7 @@ if tensor_dump_dir_env is not None:
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 class GetRecipes:
@@ -394,7 +392,7 @@ class TestFP8RecipeLinearBase:
         # recipe1
         using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
         else:
             y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -402,7 +400,7 @@ class TestFP8RecipeLinearBase:
         # recipe2
         using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                 y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
         else:
             y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
@@ -617,7 +615,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         # recipe1
         using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
                     x,
                     w,
@@ -639,7 +637,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         # recipe2
         using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                 y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
                     x,
                     w,
...
@@ -11,12 +11,11 @@ import pytest
 import torch
 import transformer_engine.common.recipe
-import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
-from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex
 # PyTorch tensor dtypes
...
@@ -11,13 +11,11 @@ import torch
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
     Float8Quantizer,
     Float8Tensor,
     Float8CurrentScalingQuantizer,
 )
-from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
 DimsType = Union[Iterable[int], int]
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 # delayed scaling
...
@@ -11,14 +11,11 @@ from torch import nn
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
-from transformer_engine.pytorch import fp8_model_init
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
 from transformer_engine.pytorch.utils import gpu_autocast_ctx
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 class TestFusedOptimizer:
@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
         build_model_context = nullcontext
         build_model_context_args = {}
         if use_fp8_params:
-            build_model_context = fp8_model_init
+            build_model_context = quantized_model_init
             build_model_context_args["enabled"] = True
         with build_model_context(**build_model_context_args):
@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
             exp_avg_sq_dtype=torch.float32,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp32_master(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
             exp_avg_sq_dtype=torch.float32,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp32_master_store_param_remainders(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
             store_param_remainders=True,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_master(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_grad(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_exp_avg(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_exp_avg(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
         self.gen_precision_aware_test(
@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=1e-2,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_exp_avg_sq(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_exp_avg_sq(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg_sq(self):
         self.gen_precision_aware_test(
@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
             skip_assert=True,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_model_weight_cast(self):
         dtype = torch.bfloat16
         model = MultiheadAttention(
@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_model_weight_cast(self):
         dtype = torch.bfloat16
-        with fp8_model_init(enabled=True, recipe=DelayedScaling()):
+        with quantized_model_init(enabled=True, recipe=DelayedScaling()):
             model = MultiheadAttention(
                 hidden_size=1024,
                 num_attention_heads=16,
...
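
Note also that the previously internal `is_bf16_compatible` helper is now public as `is_bf16_available`, matching the other `is_*_available` checks. A small sketch of the guard pattern these tests use; the dtype list is illustrative:

import torch
from transformer_engine.pytorch import is_bf16_available

# bf16 requires sm_80 or higher, so gate bf16 test dtypes on the check.
dtypes = [torch.float32, torch.float16]
if is_bf16_available():
    dtypes.append(torch.bfloat16)
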
...@@ -7,8 +7,6 @@ from __future__ import annotations ...@@ -7,8 +7,6 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import io import io
import math import math
import pathlib
import sys
from typing import Optional from typing import Optional
import pytest import pytest
...@@ -17,7 +15,6 @@ import torch ...@@ -17,7 +15,6 @@ import torch
import transformer_engine import transformer_engine
import transformer_engine.common.recipe import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias, BackwardActivationBias,
...@@ -28,28 +25,27 @@ from transformer_engine.pytorch.ops.fused import ( ...@@ -28,28 +25,27 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd, ForwardLinearBiasAdd,
ForwardLinearScaleAdd, ForwardLinearScaleAdd,
) )
from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch import (
from transformer_engine.pytorch.tensor.float8_tensor import ( QuantizedTensor,
Float8Tensor,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
Float8Quantizer, Float8Quantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
# Import utility functions # Import utility functions
from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
# Check for supported quantization schemes # Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
# Supported data types # Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16] _dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher if is_bf16_available(): # bf16 requires sm_80 or higher
_dtypes.append(torch.bfloat16) _dtypes.append(torch.bfloat16)
# Supported devices # Supported devices
...@@ -372,7 +368,7 @@ class TestFuser: ...@@ -372,7 +368,7 @@ class TestFuser:
) )
# Construct model # Construct model
with te.fp8_model_init(recipe=recipe): with te.quantized_model_init(recipe=recipe):
model = te_ops.basic.BasicLinear( model = te_ops.basic.BasicLinear(
size, size,
size, size,
...@@ -404,7 +400,7 @@ class TestFuser: ...@@ -404,7 +400,7 @@ class TestFuser:
) )
# Training step # Training step
with te.fp8_autocast(fp8_recipe=recipe): with te.autocast(recipe=recipe):
y = model(x) y = model(x)
y.backward(dy) y.backward(dy)
with torch.no_grad(): with torch.no_grad():
...@@ -473,7 +469,7 @@ class TestFuser: ...@@ -473,7 +469,7 @@ class TestFuser:
) )
# Construct operation # Construct operation
with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
with torch.no_grad(): with torch.no_grad():
op.weight.copy_(w_test) op.weight.copy_(w_test)
...@@ -530,7 +526,7 @@ class TestFuser: ...@@ -530,7 +526,7 @@ class TestFuser:
# Construct operation # Construct operation
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
# Check forward and backward pass # Check forward and backward pass
...@@ -540,7 +536,7 @@ class TestFuser: ...@@ -540,7 +536,7 @@ class TestFuser:
device=device, device=device,
requires_grad=True, requires_grad=True,
) )
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
with torch.autocast(device_type=device.type, dtype=autocast_dtype): with torch.autocast(device_type=device.type, dtype=autocast_dtype):
y = op(x) y = op(x)
y.backward(torch.zeros_like(y)) y.backward(torch.zeros_like(y))
...@@ -553,7 +549,7 @@ class TestFuser: ...@@ -553,7 +549,7 @@ class TestFuser:
x.grad = None x.grad = None
op.weight.grad = None op.weight.grad = None
with torch.autocast(device_type=device.type, dtype=autocast_dtype): with torch.autocast(device_type=device.type, dtype=autocast_dtype):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with te.autocast(enabled=quantized_compute, recipe=recipe):
y = op(x) y = op(x)
y.backward(torch.zeros_like(y)) y.backward(torch.zeros_like(y))
assert y.dtype == autocast_dtype assert y.dtype == autocast_dtype
@@ -803,7 +799,7 @@ class TestBasicOps:

        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)
@@ -897,7 +893,7 @@ class TestBasicOps:

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.BasicLinear(
                in_features,
                out_features,
@@ -914,7 +910,7 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -1075,7 +1071,7 @@ class TestBasicOps:

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
@@ -1091,7 +1087,7 @@ class TestBasicOps:
        del b_test
        for param in op.parameters():
            param.requires_grad_(requires_grad=weight_requires_grad)
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = op(x_test)
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)
@@ -1192,7 +1188,7 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -1354,7 +1350,7 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -1654,7 +1650,7 @@ class TestBasicOps:
            make_op(cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -1721,7 +1717,7 @@ class TestBasicOps:
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -1792,7 +1788,7 @@ class TestBasicOps:
            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
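All of the TestBasicOps hunks follow the same two-context pattern: parameters are created in quantized form under quantized_model_init, then compute runs under autocast with the same recipe. A condensed sketch of that pattern, assuming te_ops is the transformer_engine.pytorch.ops module (as in these tests), an FP8-capable GPU, and illustrative shapes:

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling

# Sketch: the same recipe object is passed to both context managers so
# that weight storage and compute use a consistent quantization scheme.
recipe = DelayedScaling()
with te.quantized_model_init(enabled=True, recipe=recipe):
    op = te_ops.BasicLinear(256, 256, device="cuda")
x = torch.randn(16, 256, device="cuda", requires_grad=True)
with te.autocast(enabled=True, recipe=recipe):
    y = op(x)
y.backward(torch.ones_like(y))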
@@ -2002,7 +1998,7 @@ class TestFusedOps:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2018,7 +2014,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -2112,7 +2108,7 @@ class TestFusedOps:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2129,7 +2125,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2218,7 +2214,7 @@ class TestFusedOps:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2234,7 +2230,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2325,7 +2321,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].bias.copy_(b_test)
        del b_test
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -2503,7 +2499,7 @@ class TestFusedOps:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(in_place=True),
                te_ops.Linear(
@@ -2517,7 +2513,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
@@ -2598,7 +2594,7 @@ class TestFusedOps:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2612,7 +2608,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        (y_test * dy_test).sum().backward()
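One detail worth noting: the hunks at -2503 and -2598 call quantized_model_init without a recipe argument, so both spellings are accepted. A sketch of the two forms; what happens when recipe is omitted is an assumption here (presumably a library default), not something this diff shows:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Explicit recipe: weights are allocated directly in that recipe's format.
with te.quantized_model_init(enabled=True, recipe=DelayedScaling()):
    linear_fp8 = te.Linear(256, 256, device="cuda")

# Recipe omitted, as in the -2503/-2598 hunks above; presumably this
# falls back to a default recipe (an assumption, not shown in the diff).
with te.quantized_model_init(enabled=True):
    linear_default = te.Linear(256, 256, device="cuda")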
@@ -2672,7 +2668,7 @@ class TestCheckpointing:

        # Construct model
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_save = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
@@ -2683,7 +2679,7 @@ class TestCheckpointing:
            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            dy = torch.randn(out_shape, dtype=dtype, device=device)
            optim_save.zero_grad()
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_save(x)
            y.backward(dy)
            optim_save.step()
@@ -2712,14 +2708,14 @@ class TestCheckpointing:
        ys_save = []
        for i in range(post_checkpoint_steps):
            optim_save.zero_grad()
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_save(xs_save[i])
            y.backward(dys[i])
            optim_save.step()
            ys_save.append(y)

        # Load checkpoint
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_load = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
@@ -2732,7 +2728,7 @@ class TestCheckpointing:
        ys_load = []
        for i in range(post_checkpoint_steps):
            optim_load.zero_grad()
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_load(xs_load[i])
            y.backward(dys[i])
            optim_load.step()
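The checkpointing hunks above boil down to a save/load round-trip under the renamed APIs: the loading model is constructed under the same quantized_model_init context before its state dict is restored. A condensed sketch; the file name and sizes are illustrative, and passing weights_only=False to torch.load for quantized state is an assumption:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()

# Save side: parameters created in quantized form, then checkpointed.
with te.quantized_model_init(enabled=True, recipe=recipe):
    model_save = te.Linear(128, 128, device="cuda")
torch.save(model_save.state_dict(), "ckpt.pt")

# Load side: construct the model the same way before restoring state.
with te.quantized_model_init(enabled=True, recipe=recipe):
    model_load = te.Linear(128, 128, device="cuda")
model_load.load_state_dict(torch.load("ckpt.pt", weights_only=False))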
@@ -2819,7 +2815,7 @@ class TestSequentialModules:

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            if normalization == "LayerNorm":
                norm = te_ops.LayerNorm(
                    hidden_size,
@@ -2850,6 +2846,6 @@ class TestSequentialModules:
                dtype=dtype,
            )
        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

-from transformer_engine.pytorch.transformer import TransformerLayer
+from transformer_engine.pytorch import TransformerLayer


class SimpleTEModel(PreTrainedModel):