Unverified commit aac74427, authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Prune L0 unit test (#1999)



* Add verbosity only for failing tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Prune some tests and preinit recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Prune further tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix multitensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix a100
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c6c1f50e
@@ -26,30 +26,30 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
 pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
-NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
+NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
...
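Editor's note on the hunk above: dropping "-v -s" in favor of "--tb=auto" keeps the CI log quiet for passing tests and prints a traceback only when a test fails. A minimal sketch of the same flag set driven from Python via pytest's public entry point; the report path and test file below are illustrative stand-ins, not paths taken from this PR:

import pytest

# Hedged sketch: run one test file with tracebacks only for failures.
# "--tb=auto" replaces the pruned "-v -s" pair, which had logged every test
# name and disabled output capture.
ret = pytest.main([
    "--tb=auto",
    "--junitxml=report.xml",          # assumed report location
    "tests/pytorch/test_sanity.py",   # assumed test file path
])
raise SystemExit(ret)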
@@ -6,7 +6,7 @@ import math
 import os
 import sys
 import pathlib
-from typing import Any, Dict, List, Tuple, Union, Optional
+from typing import Any, Dict, Tuple, Union
 import pytest
 import torch
@@ -20,10 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
-    get_attention_backend,
     check_set_window_size,
 )
-from transformer_engine.pytorch.attention import InferenceParams
 from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
@@ -54,7 +52,6 @@ from utils import (
     reset_rng_states,
     ModelConfig,
     dtype_tols,
-    logging_context,
     get_available_attention_backends,
 )
...
@@ -14,15 +14,12 @@ from transformer_engine.pytorch.attention.dot_product_attention import _attentio
 from utils import ModelConfig, get_available_attention_backends
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_recipes = [
-    None,  # non-fp8
-    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet
-    recipe.Float8CurrentScaling(),
-    recipe.DelayedScaling(),
-]
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+fp8_recipes = [None]
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
 model_config = {
     "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
@@ -129,12 +126,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
     model_cls = model_types[model_key]
     models_list = [model_cls() for _ in range(NUM_LAYERS)]
-    if fp8_recipe and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None:
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
     if model_key in ["multihead_attention", "transformer_layer"]:
         available_backends, *_ = get_available_attention_backends(
             model_config["small"],
...
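Editor's note on the "preinit recipe" pattern above: the list of FP8 recipes is now built once at import time and gated on what the GPU actually supports, so unsupported recipes never become pytest parameters and the per-test skip logic can be deleted. A standalone distillation of that pattern, assuming a CUDA build of Transformer Engine is installed; only the module-level variable names are mine:

# Hedged sketch of an availability-gated recipe list, mirroring the hunk above.
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, _ = FP8GlobalStateManager.is_fp8_available()

fp8_recipes = [None]                      # always exercise the non-FP8 path
if fp8_available:                         # add only recipes this GPU can run
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

The payoff is that parametrizing over fp8_recipes yields exactly the runnable configurations, instead of generating skipped test cases on hardware without FP8.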
@@ -2,9 +2,7 @@
 #
 # See LICENSE for license information.
-from dataclasses import dataclass
-import itertools
-from typing import Iterable, List, Tuple, Union
+from typing import Iterable, List, Union
 import pytest
 import torch
@@ -26,11 +24,9 @@ from transformer_engine.common import recipe
 from utils import ModelConfig, reset_rng_states
 # Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
 # Reset RNG states.
 reset_rng_states()
@@ -39,12 +35,14 @@ model_configs = {
     "small": ModelConfig(32, 2, 2, 32),
 }
-fp8_recipes = [
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
 # Supported data types
 dtypes: List[torch.dtype] = [torch.float32, torch.float16]
@@ -277,35 +275,27 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("fp8", (False, True))
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
 def test_make_graphed_callables(
     *,
     module: str,
     model_config: str = "small",
     num_layers: int = 3,
    dtype: torch.dtype,
-    fp8: bool,
     fp8_params: bool,
     fp8_recipe: recipe.Recipe,
     fp8_weight_caching: bool = False,
 ) -> None:
-    # Skip invalid configurations.
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+    fp8 = fp8_recipe is not None
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
-    if fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_recipe.float8_block_scaling() and module == "linear_op":
+    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
         pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
@@ -336,7 +326,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
 ]
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize(
     "module",
     _test_make_graphed_callables_with_fp8_weight_caching_modules,
@@ -352,7 +341,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
     test_make_graphed_callables(
         module=module,
         dtype=torch.float32,
-        fp8=True,
         fp8_params=fp8_params,
         fp8_recipe=fp8_recipe,
         fp8_weight_caching=True,
...
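Editor's note on the hunk above: the separate boolean "fp8" parameter is folded into the recipe axis, with the test parametrized over fp8_recipes + [None] and fp8 derived as "fp8_recipe is not None". That removes the contradictory fp8/recipe combinations from the test matrix entirely. A minimal, self-contained sketch of the idea; run_model and the empty recipe list are hypothetical stand-ins for the real graphed-callables test body and the availability-gated list:

import pytest

fp8_recipes = []            # would be populated per hardware support, as above

def run_model(fp8, fp8_recipe):
    return 0.0              # placeholder for the real forward/backward check

@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables_sketch(fp8_recipe):
    fp8 = fp8_recipe is not None    # derived, so no invalid pairings to skip
    run_model(fp8, fp8_recipe)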
@@ -2,7 +2,6 @@
 #
 # See LICENSE for license information.
-from itertools import product
 import copy
 from contextlib import nullcontext
...
@@ -2,8 +2,7 @@
 #
 # See LICENSE for license information.
 import torch
-import math
-from typing import Optional, Dict
+from typing import Optional
 from transformer_engine.pytorch.router import (
     fused_topk_with_score_function,
     fused_compute_score_for_moe_aux_loss,
...
@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 from transformer_engine.pytorch.transformer import TransformerLayer
-from transformer_engine.pytorch.utils import is_bf16_compatible
 class SimpleTEModel(PreTrainedModel):
...
@@ -2,7 +2,6 @@
 #
 # See LICENSE for license information.
-from collections import OrderedDict
 import math
 import os
 from typing import Dict, List, Tuple, Optional
@@ -37,23 +36,20 @@ from transformer_engine.pytorch import (
     Fp8Padding,
     Fp8Unpadding,
 )
-from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
+from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states, get_available_attention_backends
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -103,18 +99,21 @@ if NVTE_TEST_NVINSPECT_ENABLED:
         feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
     )
-fp8_recipes = [
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
 def is_fused_attn_available(
     config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
 ):
-    available_backends, _, fused_attn_backends = get_available_attention_backends(
+    _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
@@ -571,14 +570,8 @@ def _test_e2e_selective_recompute(
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -687,14 +680,8 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -1263,8 +1250,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
         te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
     )
-    # Shoule be bit-wise match
-    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+    # Should be bit-wise match
+    for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@@ -1276,12 +1263,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
@@ -1649,14 +1631,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("bs", [2])
 @pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("activation", all_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
+    dtype, bs, model, bias, fuse_wgrad_accumulation
 ):
     config = model_configs[model]
@@ -1665,7 +1645,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         ffn_hidden_size=4 * config.hidden_size,
         eps=config.eps,
         bias=bias,
-        normalization=normalization,
         params_dtype=dtype,
         device="cuda",
         delay_wgrad_compute=True,
@@ -1677,7 +1656,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         ffn_hidden_size=4 * config.hidden_size,
         eps=config.eps,
         bias=bias,
-        normalization=normalization,
         params_dtype=dtype,
         device="cuda",
         delay_wgrad_compute=False,
@@ -1687,7 +1665,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
     # Share params
     with torch.no_grad():
         ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
         ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
         ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
@@ -1802,14 +1779,8 @@ def test_grouped_linear_accuracy(
     parallel_mode=None,
 ):
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -1904,14 +1875,8 @@ def test_grouped_linear_accuracy_save_original_input(
     parallel_mode=None,
 ):
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
@@ -2114,14 +2079,8 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2189,14 +2148,8 @@ def test_padding_grouped_linear_accuracy_save_original_input(
     fp8_model_params,
     parallel_mode=None,
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
@@ -2410,14 +2363,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
-    if not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -2645,9 +2592,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
         (16, 4096, 128, 512),
     ],
 )
-@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
 @pytest.mark.parametrize("accumulate", [False, True])
-def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
+def test_fp8_grouped_gemm(shape, accumulate):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
...
@@ -27,7 +27,6 @@ import warnings
 import numpy as np
 import onnxruntime as ort
 import torch
-import random
 from torch import nn as nn
 from typing import Optional, Union, Tuple, List
 from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
-fp8_recipes = [
-    None,
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -369,14 +367,6 @@ def validate_result(
     )
-def create_meta(scale_factor: float, size: int = 1):
-    meta = tex.FP8TensorMeta()
-    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
-    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
-    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
-    return meta
 def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
     if fake_bf16_io:
         assert dtype == torch.bfloat16
@@ -413,36 +403,12 @@ Test cases begin here.
 """
-@pytest.mark.parametrize("scale_factor", [112])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-# Returning the bias is a TE fusion optimization we don't care about.
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, False),
-        (torch.float16, True),
-        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
-        (torch.bfloat16, False),
-        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
-        (torch.bfloat16, True),
-    ],
-)
-def test_export_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    precision: torch.dtype,
+def _test_export_linear(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    precision: torch.dtype = torch.float32,
 ):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
@@ -498,32 +464,28 @@ def test_export_linear(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize(
-    "precision",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_export_linear_use_bias(seed_default_rng, use_bias):
+    _test_export_linear(use_bias=use_bias)
+@pytest.mark.parametrize("return_bias", [True, False])
+def test_export_linear_return_bias(seed_default_rng, return_bias):
+    _test_export_linear(return_bias=return_bias)
+def _test_export_layernorm(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
     # Set dimensions (these are arbitrary).
     batch_size = 4
     in_features = 64
@@ -564,39 +526,31 @@ def test_export_layernorm(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm(zero_centered_gamma=True)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
+def test_export_layernorm_normalization(seed_default_rng, normalization):
+    _test_export_layernorm(normalization=normalization)
+def _test_export_layernorm_linear(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
 ):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_mlp(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    activation: str,
-    normalization: str,
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    _test_export_layernorm_linear(return_layernorm_output=True)
+def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_linear(zero_centered_gamma=True)
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_linear(normalization=normalization)
+def test_export_layernorm_linear_no_bias(seed_default_rng):
+    _test_export_layernorm_linear(use_bias=False)
+def test_export_layernorm_linear_return_bias(seed_default_rng):
+    _test_export_layernorm_linear(return_bias=True)
+def _test_export_layernorm_mlp(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+    normalization: str = all_normalizations[0],
 ):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
     )
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    _test_export_layernorm_mlp(return_layernorm_output=True)
+def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    _test_export_layernorm_mlp(return_bias=True)
+def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    _test_export_layernorm_mlp(use_bias=False)
+def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_mlp(zero_centered_gamma=True)
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_mlp(normalization=normalization)
+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    _test_export_layernorm_mlp(activation=activation)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
     ],
 )
 def test_export_core_attention(
-    seed_default_rng,
-    set_max_seq_len,
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
@@ -777,11 +764,6 @@ def test_export_core_attention(
     )
-test_configs_multihead_attention = [
-    # "use_mask, attn_mask_type"
-    (False, "no_mask"),  # calls ScaledSoftmax
-    (True, "arbitrary"),  # calls ScaledMaskedSoftmax
-]
 test_configs_attention_type = [
     # "input_layernorm, attention_type, fuse_qkv_params"
     (True, "self", True),
@@ -795,31 +777,14 @@ test_configs_attention_type = [
 ]
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("return_layernorm_output", [False])
-@pytest.mark.parametrize(
-    "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
-)
-def test_export_multihead_attention(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    precision: torch.dtype,
-    return_layernorm_output: bool,
-    input_layernorm: bool,
-    attention_type: str,
-    fuse_qkv_params: bool,
+def _test_export_multihead_attention(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    precision: torch.dtype = torch.float32,
+    input_layernorm: bool = True,
+    attention_type: str = "self",
+    fuse_qkv_params: bool = True,
 ):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     hidden_size = 256
     sequence_length = 128
     batch_size = 4
@@ -837,6 +802,7 @@ def test_export_multihead_attention(
         init_method,
         output_layer_init_method,
     )
+    attn_mask_type = "arbitrary" if use_mask else "no_mask"
     hidden_states_context = torch.randn(
         sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
@@ -868,7 +834,7 @@ def test_export_multihead_attention(
         *attention_args,
         attn_mask_type=attn_mask_type,
         params_dtype=precision,
-        return_layernorm_output=return_layernorm_output,
+        return_layernorm_output=False,
         input_layernorm=input_layernorm,
         attention_type=attention_type,
         fuse_qkv_params=fuse_qkv_params,
@@ -960,30 +926,37 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("output_layernorm", [True, False])
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("fuse_qkv_params", [False, True])
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-def test_export_transformer_layer(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    output_layernorm: bool,
-    precision: torch.dtype,
-    fuse_qkv_params: bool,
-    zero_centered_gamma: bool,
-    activation: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+def test_export_multihead_attention_recipe(fp8_recipe, precision):
+    _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_multihead_attention_no_mask():
+    _test_export_multihead_attention(use_mask=False)
+def test_export_multihead_attention_no_input_layernorm():
+    _test_export_multihead_attention(input_layernorm=False)
+def test_export_multihead_attention_cross_attn():
+    _test_export_multihead_attention(attention_type="cross")
+def test_export_multihead_attention_unfused_qkv_params():
+    _test_export_multihead_attention(fuse_qkv_params=False)
+def _test_export_transformer_layer(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    attn_mask_type: str = "arbitrary",
+    output_layernorm: bool = False,
+    precision: torch.dtype = torch.float32,
+    fuse_qkv_params: bool = True,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+):
     # Layer configuration
     hidden_size = 64
     sequence_length = 128
@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
     )
-@skip_FP8
-@skip_MXFP8
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_transformer_layer_recipe(fp8_recipe, precision):
+    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_transformer_layer_no_mask():
+    _test_export_transformer_layer(use_mask=False)
+def test_export_transformer_layer_output_layernorm():
+    _test_export_transformer_layer(output_layernorm=True)
+def test_export_transformer_layer_unfused_qkv_params():
+    _test_export_transformer_layer(fuse_qkv_params=False)
+def test_export_transformer_layer_zero_centered_gamma():
+    _test_export_transformer_layer(zero_centered_gamma=True)
+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_transformer_layer_activation(activation):
+    _test_export_transformer_layer(activation=activation)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("zero_centered_gamma", [True])
 def test_export_gpt_generation(
-    seed_default_rng,
-    set_max_seq_len,
     fp8_recipe: recipe.Recipe,
     precision: torch.dtype,
-    zero_centered_gamma: bool,
 ):
     """Test that the ONNX model can correctly handle inputs with different shapes and that
     the attention mask is adjusted on-the-fly to different sequence lengths.
     """
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     # Layer configuration
     hidden_size = 64
     sequence_length = 128
@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        zero_centered_gamma=zero_centered_gamma,
     ).to(device="cuda")
     # "Context phase": use full input sequence length
...
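Editor's note on the ONNX export refactor above: each fully cross-parametrized test is split into a private _test_export_* helper with sensible defaults plus small wrapper tests, so every option is exercised once against the default configuration instead of the full Cartesian product. A hedged, self-contained sketch of that layout; _export_and_check and its options are hypothetical placeholders, not names from the PR:

import pytest
import torch

def _export_and_check(precision=torch.float32, use_bias=True, zero_centered_gamma=False):
    # Placeholder body: the real helper would build the module, export to ONNX,
    # and compare ONNX Runtime outputs against the PyTorch reference.
    assert precision in (torch.float32, torch.float16, torch.bfloat16)

@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_precision(precision):
    _export_and_check(precision=precision)     # sweep one axis only

def test_export_no_bias():
    _export_and_check(use_bias=False)          # every other option stays default

def test_export_zero_centered_gamma():
    _export_and_check(zero_centered_gamma=True)

The test count grows as the sum of the option values rather than their product, which is the main source of the L0 runtime reduction in this file.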
@@ -3,7 +3,6 @@
 # See LICENSE for license information.
 import random
-import pytest
 import torch
 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
...
@@ -2,9 +2,7 @@
 #
 # See LICENSE for license information.
-from dataclasses import dataclass
 from typing import Optional
-from contextlib import nullcontext
 import torch
 import pytest
@@ -17,11 +15,9 @@ from transformer_engine.pytorch.fp8 import (
     fp8_model_init,
 )
 from transformer_engine.pytorch.utils import (
-    get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
-    get_cudnn_version,
 )
 from transformer_engine.pytorch import (
     LayerNormLinear,
@@ -31,7 +27,6 @@ from transformer_engine.pytorch import (
     TransformerLayer,
     RMSNorm,
     LayerNorm,
-    get_cpu_offload_context,
 )
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
@@ -46,13 +41,11 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
-from utils import ModelConfig, dtype_tols
+from utils import ModelConfig
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 # Record initial RNG state from script run.
...@@ -76,33 +69,6 @@ if NVTE_TEST_NVINSPECT_ENABLED: ...@@ -76,33 +69,6 @@ if NVTE_TEST_NVINSPECT_ENABLED:
) )
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: torch.Tensor,
recipe: recipe.DelayedScaling,
) -> torch.Tensor:
"""Custom func to test recipe."""
sf = fp8_max / amax
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
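For reference, the two hooks above existed only to exercise the pruned DelayedScaling variant; a short sketch of how such callables plug into the recipe, mirroring the configuration removed further below (it reuses the hook functions defined above):

from transformer_engine.common import recipe

custom_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.E4M3,
    amax_compute_algo=custom_amax_compute,             # min over the amax history
    scaling_factor_compute_algo=custom_amax_to_scale,  # guard against zero / non-finite amax
)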
def is_fp8_supported(config: ModelConfig): def is_fp8_supported(config: ModelConfig):
if ( if (
config.max_seqlen_q * config.batch_size % 16 config.max_seqlen_q * config.batch_size % 16
...@@ -121,22 +87,15 @@ model_configs = { ...@@ -121,22 +87,15 @@ model_configs = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1), "large": ModelConfig(2, 128, 4, 128, num_layers=1),
} }
fp8_recipes = [ fp8_recipes = []
None, # Test non-FP8 if mxfp8_available:
recipe.MXFP8BlockScaling(), # Test default fp8_recipes.append(recipe.MXFP8BlockScaling())
recipe.Float8CurrentScaling(), # Test default if fp8_block_scaling_available:
recipe.Float8BlockScaling(), # Test default fp8_recipes.append(recipe.Float8BlockScaling())
recipe.DelayedScaling(), # Test default if fp8_available:
recipe.DelayedScaling( # Test most_recent algo fp8_recipes.append(recipe.Float8CurrentScaling())
amax_history_len=16, fp8_recipes.append(recipe.DelayedScaling())
amax_compute_algo="most_recent", fp8_recipes.append(None)
),
recipe.DelayedScaling( # Test custom amax and scale compute algo
fp8_format=recipe.Format.E4M3,
amax_compute_algo=custom_amax_compute,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
param_types = [torch.float32, torch.float16] param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher if is_bf16_compatible(): # bf16 requires sm_80 or higher
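Because the recipe list is now filtered by hardware capability at import time, a parametrized test only keeps the model-level guard. A minimal sketch of the resulting pattern (the test name is illustrative and the body is abbreviated):

@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_sanity_example(fp8_recipe):
    # Recipes the device cannot run were never added to fp8_recipes,
    # so no per-recipe hardware skips are needed here.
    config = model_configs["small"]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        ...  # forward/backward as in the sanity helpers below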
...@@ -160,63 +119,6 @@ def reset_global_fp8_state(): ...@@ -160,63 +119,6 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset() FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
use_fp8 = fp8_recipe is not None
if skip_wgrad:
_disable_wgrads(block)
# Pre graph capture warmup in a separate stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# Capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
optimizer.step()
# Fills the graph's input memory with new data to compute on
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
g.replay()
torch.cuda.synchronize()
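The removed helper above captured and replayed the graph by hand. Much the same warmup-capture-replay flow can be expressed with torch.cuda.make_graphed_callables; the sketch below is a plain non-FP8 illustration with an arbitrary module and shapes, not a drop-in replacement for the deleted test.

import torch

module = torch.nn.Linear(64, 64).cuda()
sample_input = torch.randn(128, 64, device="cuda", requires_grad=True)

# Warms up on a side stream, then captures forward and backward graphs.
graphed_module = torch.cuda.make_graphed_callables(module, (sample_input,))

real_input = torch.rand_like(sample_input)
out = graphed_module(real_input)  # replays the captured forward graph
out.sum().backward()              # replays the captured backward graph
torch.cuda.synchronize()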
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
...@@ -292,7 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci ...@@ -292,7 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
...@@ -303,16 +205,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): ...@@ -303,16 +205,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states) te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum() loss = te_out.sum()
loss.backward() loss.backward()
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -471,12 +366,6 @@ def test_sanity_layernorm_linear( ...@@ -471,12 +366,6 @@ def test_sanity_layernorm_linear(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -505,12 +394,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba ...@@ -505,12 +394,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -541,12 +424,6 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -541,12 +424,6 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens = bs * config.max_seqlen_q num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -586,12 +463,6 @@ def test_sanity_grouped_linear( ...@@ -586,12 +463,6 @@ def test_sanity_grouped_linear(
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -642,12 +513,6 @@ def test_sanity_layernorm_mlp( ...@@ -642,12 +513,6 @@ def test_sanity_layernorm_mlp(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -673,35 +538,23 @@ def test_sanity_layernorm_mlp( ...@@ -673,35 +538,23 @@ def test_sanity_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt( def test_sanity_gpt(
dtype, dtype,
fp8_recipe, fp8_recipe,
model, model,
skip_wgrad, skip_wgrad,
zero_centered_gamma,
bias, bias,
activation, activation,
normalization, normalization,
parallel_attention_mlp, parallel_attention_mlp,
cpu_offload,
): ):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -721,7 +574,6 @@ def test_sanity_gpt( ...@@ -721,7 +574,6 @@ def test_sanity_gpt(
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias, bias=bias,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
...@@ -729,7 +581,7 @@ def test_sanity_gpt( ...@@ -729,7 +581,7 @@ def test_sanity_gpt(
parallel_attention_mlp=parallel_attention_mlp, parallel_attention_mlp=parallel_attention_mlp,
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_gpt_126m(): def test_sanity_gpt_126m():
...@@ -746,12 +598,10 @@ def test_sanity_gpt_126m(): ...@@ -746,12 +598,10 @@ def test_sanity_gpt_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=True,
bias=True, bias=True,
activation="gelu", activation="gelu",
normalization="LayerNorm", normalization="LayerNorm",
parallel_attention_mlp=False, parallel_attention_mlp=False,
cpu_offload=False,
) )
...@@ -759,18 +609,13 @@ def test_sanity_gpt_126m(): ...@@ -759,18 +609,13 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -790,7 +635,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -790,7 +635,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=True, apply_residual_connection_post_layernorm=True,
output_layernorm=True, output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="causal", self_attn_mask_type="causal",
normalization=normalization, normalization=normalization,
device="cuda", device="cuda",
...@@ -811,7 +655,6 @@ def test_sanity_bert_126m(): ...@@ -811,7 +655,6 @@ def test_sanity_bert_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm", normalization="LayerNorm",
) )
...@@ -820,18 +663,13 @@ def test_sanity_bert_126m(): ...@@ -820,18 +663,13 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -852,7 +690,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no ...@@ -852,7 +690,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
layer_type="decoder", layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization, normalization=normalization,
device="cuda", device="cuda",
) )
...@@ -872,7 +709,6 @@ def test_sanity_T5_126m(): ...@@ -872,7 +709,6 @@ def test_sanity_T5_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm", normalization="LayerNorm",
) )
...@@ -885,12 +721,6 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -885,12 +721,6 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -917,17 +747,10 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -917,17 +747,10 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_drop_path(dtype, fp8_recipe, model):
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -951,7 +774,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -951,7 +774,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
device="cuda", device="cuda",
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, False)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -962,12 +785,6 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -962,12 +785,6 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -991,26 +808,17 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -991,26 +808,17 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
device="cuda", device="cuda",
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -1030,7 +838,6 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -1030,7 +838,6 @@ def test_sanity_gradient_accumulation_fusion(
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True, fuse_qkv_params=True,
fuse_wgrad_accumulation=True, fuse_wgrad_accumulation=True,
device="cuda", device="cuda",
...@@ -1039,52 +846,6 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -1039,52 +846,6 @@ def test_sanity_gradient_accumulation_fusion(
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
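A note on the fused path exercised here: with fuse_wgrad_accumulation=True the weight-gradient GEMM is assumed to accumulate directly into a pre-allocated main_grad buffer on the weight instead of populating param.grad (Megatron-style). A rough sketch under that assumption, with illustrative shapes:

layer = Linear(64, 64, fuse_wgrad_accumulation=True, device="cuda")

# The buffer must exist before backward; the fused wgrad adds into it.
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

inp = torch.randn(128, 64, device="cuda", requires_grad=True)
layer(inp).sum().backward()

assert torch.count_nonzero(layer.weight.main_grad) > 0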
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast(): def test_model_multiple_cast():
a = torch.zeros((16, 16), device="cuda") a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32) m = Linear(16, 32)
......