Unverified commit aac74427, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] Prune L0 unit test (#1999)



* Add verbosity only for failing tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Prune some tests and preinit recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Prune further tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix multitensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix a100
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c6c1f50e
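The "preinit recipe" part of this change follows one pattern in every touched test file: hardware support is queried once at import time and the `fp8_recipes` list is built only from recipes that can actually run, so the per-test `pytest.skip` branches can be deleted. A condensed sketch of that pattern, assembled from the hunks below (all names appear in the diff; whether the non-FP8 `None` entry sits in the list or is appended at the parametrize site varies by file):

```python
# Condensed sketch of the recipe pre-initialization used across the pruned
# test files: query hardware support once at import time and only build the
# recipes that can run, instead of calling pytest.skip inside each test.
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()

fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)  # the non-FP8 path is always exercised
```

Tests then parametrize directly over `fp8_recipes` (or `fp8_recipes + [None]`) and need no per-recipe hardware skips.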
......@@ -26,30 +26,30 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
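The script changes above implement the "verbosity only for failing tests" item: `-v -s` is dropped and `--tb=auto` is passed explicitly, so per-test output stays terse while failures still print a full traceback. A minimal illustration of the difference, expressed through `pytest.main` with a placeholder test path (the script itself passes the real `$TE_PATH` files and `--junitxml` log locations):

```python
# Illustrative sketch only; the CI script invokes pytest from the shell with
# the real $TE_PATH test files and --junitxml log paths.
import pytest

# Old invocation: "-v -s" lists every test and streams all captured stdout,
# whether the test passes or fails.
# pytest.main(["-v", "-s", "tests/pytorch/test_sanity.py"])

# New invocation: terse default output; "--tb=auto" still prints a full
# traceback, but only for the tests that fail.
exit_code = pytest.main(["--tb=auto", "tests/pytorch/test_sanity.py"])
raise SystemExit(exit_code)
```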
......@@ -6,7 +6,7 @@ import math
import os
import sys
import pathlib
from typing import Any, Dict, List, Tuple, Union, Optional
from typing import Any, Dict, Tuple, Union
import pytest
import torch
......@@ -20,10 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
)
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
......@@ -54,7 +52,6 @@ from utils import (
reset_rng_states,
ModelConfig,
dtype_tols,
logging_context,
get_available_attention_backends,
)
......
......@@ -14,15 +14,12 @@ from transformer_engine.pytorch.attention.dot_product_attention import _attentio
from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_recipes = [
None, # non-fp8
# recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet
recipe.Float8CurrentScaling(),
recipe.DelayedScaling(),
]
fp8_recipes = [None]
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
......@@ -129,12 +126,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]
if fp8_recipe and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None:
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if model_key in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
......
......@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union
from typing import Iterable, List, Union
import pytest
import torch
......@@ -26,11 +24,9 @@ from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
......@@ -39,12 +35,14 @@ model_configs = {
"small": ModelConfig(32, 2, 2, 32),
}
fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
......@@ -277,35 +275,27 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables(
*,
module: str,
model_config: str = "small",
num_layers: int = 3,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
fp8_weight_caching: bool = False,
) -> None:
# Skip invalid configurations.
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
fp8 = fp8_recipe is not None
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and module == "linear_op":
if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
......@@ -336,7 +326,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
......@@ -352,7 +341,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
test_make_graphed_callables(
module=module,
dtype=torch.float32,
fp8=True,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
......
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
from itertools import product
import copy
from contextlib import nullcontext
......
......@@ -2,8 +2,7 @@
#
# See LICENSE for license information.
import torch
import math
from typing import Optional, Dict
from typing import Optional
from transformer_engine.pytorch.router import (
fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss,
......
......@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel):
......
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
......@@ -37,23 +36,20 @@ from transformer_engine.pytorch import (
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -103,18 +99,21 @@ if NVTE_TEST_NVINSPECT_ENABLED:
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
):
available_backends, _, fused_attn_backends = get_available_attention_backends(
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -571,14 +570,8 @@ def _test_e2e_selective_recompute(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -687,14 +680,8 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(
dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -1263,8 +1250,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Shoule be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
# Should be bit-wise match
for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
......@@ -1276,12 +1263,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
fuse_wgrad_accumulation = True
fp8_model_params = False
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
......@@ -1649,14 +1631,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
dtype, bs, model, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
......@@ -1665,7 +1645,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
......@@ -1677,7 +1656,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
......@@ -1687,7 +1665,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
......@@ -1802,14 +1779,8 @@ def test_grouped_linear_accuracy(
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -1904,14 +1875,8 @@ def test_grouped_linear_accuracy_save_original_input(
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
......@@ -2114,14 +2079,8 @@ def test_padding_grouped_linear_accuracy(
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -2189,14 +2148,8 @@ def test_padding_grouped_linear_accuracy_save_original_input(
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
......@@ -2410,14 +2363,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -2645,9 +2592,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
(16, 4096, 128, 512),
],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
def test_fp8_grouped_gemm(shape, accumulate):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
......
......@@ -27,7 +27,6 @@ import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
......@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
fp8_recipes = [
None,
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......@@ -369,14 +367,6 @@ def validate_result(
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
if fake_bf16_io:
assert dtype == torch.bfloat16
......@@ -413,36 +403,12 @@ Test cases begin here.
"""
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
],
)
def test_export_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
precision: torch.dtype,
def _test_export_linear(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
precision: torch.dtype = torch.float32,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -498,32 +464,28 @@ def test_export_linear(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize(
"precision",
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
_test_export_linear(use_bias=use_bias)
@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
_test_export_linear(return_bias=return_bias)
def _test_export_layernorm(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
normalization: str = all_normalizations[0],
):
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
......@@ -564,39 +526,31 @@ def test_export_layernorm(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_zero_centered_gamma(seed_default_rng):
_test_export_layernorm(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
def test_export_layernorm_normalization(seed_default_rng, normalization):
_test_export_layernorm(normalization=normalization)
def _test_export_layernorm_linear(
scale_factor: float = 112,
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
return_layernorm_output: bool = False,
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
normalization: str = all_normalizations[0],
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
activation: str,
normalization: str,
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_linear_return_ln_out(seed_default_rng):
_test_export_layernorm_linear(return_layernorm_output=True)
def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
_test_export_layernorm_linear(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
_test_export_layernorm_linear(normalization=normalization)
def test_export_layernorm_linear_no_bias(seed_default_rng):
_test_export_layernorm_linear(use_bias=False)
def test_export_layernorm_linear_return_bias(seed_default_rng):
_test_export_layernorm_linear(return_bias=True)
def _test_export_layernorm_mlp(
scale_factor: float = 112,
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
return_layernorm_output: bool = False,
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
activation: str = supported_activations[0],
normalization: str = all_normalizations[0],
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
_test_export_layernorm_mlp(return_layernorm_output=True)
def test_export_layernorm_mlp_return_bias(seed_default_rng):
_test_export_layernorm_mlp(return_bias=True)
def test_export_layernorm_mlp_no_bias(seed_default_rng):
_test_export_layernorm_mlp(use_bias=False)
def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
_test_export_layernorm_mlp(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
_test_export_layernorm_mlp(normalization=normalization)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
_test_export_layernorm_mlp(activation=activation)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
......@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
],
)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
......@@ -777,11 +764,6 @@ def test_export_core_attention(
)
test_configs_multihead_attention = [
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
# "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
......@@ -795,31 +777,14 @@ test_configs_attention_type = [
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool,
def _test_export_multihead_attention(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_mask: bool = True,
precision: torch.dtype = torch.float32,
input_layernorm: bool = True,
attention_type: str = "self",
fuse_qkv_params: bool = True,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
hidden_size = 256
sequence_length = 128
batch_size = 4
......@@ -837,6 +802,7 @@ def test_export_multihead_attention(
init_method,
output_layer_init_method,
)
attn_mask_type = "arbitrary" if use_mask else "no_mask"
hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
......@@ -868,7 +834,7 @@ def test_export_multihead_attention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
return_layernorm_output=False,
input_layernorm=input_layernorm,
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
......@@ -960,30 +926,37 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
def test_export_multihead_attention_recipe(fp8_recipe, precision):
_test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
def test_export_multihead_attention_no_mask():
_test_export_multihead_attention(use_mask=False)
def test_export_multihead_attention_no_input_layernorm():
_test_export_multihead_attention(input_layernorm=False)
def test_export_multihead_attention_cross_attn():
_test_export_multihead_attention(attention_type="cross")
def test_export_multihead_attention_unfused_qkv_params():
_test_export_multihead_attention(fuse_qkv_params=False)
def _test_export_transformer_layer(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_mask: bool = True,
attn_mask_type: str = "arbitrary",
output_layernorm: bool = False,
precision: torch.dtype = torch.float32,
fuse_qkv_params: bool = True,
zero_centered_gamma: bool = False,
activation: str = supported_activations[0],
):
# Layer configuration
hidden_size = 64
sequence_length = 128
......@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
)
@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
_test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
def test_export_transformer_layer_no_mask():
_test_export_transformer_layer(use_mask=False)
def test_export_transformer_layer_output_layernorm():
_test_export_transformer_layer(output_layernorm=True)
def test_export_transformer_layer_unfused_qkv_params():
_test_export_transformer_layer(fuse_qkv_params=False)
def test_export_transformer_layer_zero_centered_gamma():
_test_export_transformer_layer(zero_centered_gamma=True)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
_test_export_transformer_layer(activation=activation)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
......@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length
......
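The ONNX-export hunks above all apply the same pruning recipe: each fully parametrized test becomes a private `_test_export_*` helper with default arguments, and a few small wrappers each exercise one non-default setting, so the full Cartesian product of parameters no longer runs in L0. A simplified sketch of the shape of that refactor (helper body elided, wrapper names abbreviated relative to `test_onnx_export.py`):

```python
import pytest
import torch

def _test_export_linear(
    fp8_recipe=None,                 # fp8_recipes[0] in the real file
    use_bias: bool = True,
    return_bias: bool = False,
    precision: torch.dtype = torch.float32,
):
    # Build the module, export it to ONNX, and validate the result against
    # the PyTorch reference (details elided in this sketch).
    ...

# The recipe/precision sweep is the only combination kept as a parametrization.
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(precision):
    _test_export_linear(precision=precision)

# Every other knob gets one dedicated wrapper that flips a single default.
def test_export_linear_no_bias():
    _test_export_linear(use_bias=False)

def test_export_linear_return_bias():
    _test_export_linear(return_bias=True)
```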
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
import random
import pytest
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
......
......@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext
import torch
import pytest
......@@ -17,11 +15,9 @@ from transformer_engine.pytorch.fp8 import (
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
get_cudnn_version,
)
from transformer_engine.pytorch import (
LayerNormLinear,
......@@ -31,7 +27,6 @@ from transformer_engine.pytorch import (
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
......@@ -46,13 +41,11 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import ModelConfig, dtype_tols
from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
......@@ -76,33 +69,6 @@ if NVTE_TEST_NVINSPECT_ENABLED:
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: torch.Tensor,
recipe: recipe.DelayedScaling,
) -> torch.Tensor:
"""Custom func to test recipe."""
sf = fp8_max / amax
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
def is_fp8_supported(config: ModelConfig):
if (
config.max_seqlen_q * config.batch_size % 16
......@@ -121,22 +87,15 @@ model_configs = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
fp8_recipes = [
None, # Test non-FP8
recipe.MXFP8BlockScaling(), # Test default
recipe.Float8CurrentScaling(), # Test default
recipe.Float8BlockScaling(), # Test default
recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( # Test most_recent algo
amax_history_len=16,
amax_compute_algo="most_recent",
),
recipe.DelayedScaling( # Test custom amax and scale compute algo
fp8_format=recipe.Format.E4M3,
amax_compute_algo=custom_amax_compute,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
......@@ -160,63 +119,6 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
use_fp8 = fp8_recipe is not None
if skip_wgrad:
_disable_wgrads(block)
# Pre graph capture warmup in a separate stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# Capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
optimizer.step()
# Fills the graph's input memory with new data to compute on
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
g.replay()
torch.cuda.synchronize()
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
......@@ -292,7 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
......@@ -303,16 +205,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
if skip_wgrad:
_disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
......@@ -471,12 +366,6 @@ def test_sanity_layernorm_linear(
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -505,12 +394,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -541,12 +424,6 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -586,12 +463,6 @@ def test_sanity_grouped_linear(
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -642,12 +513,6 @@ def test_sanity_layernorm_mlp(
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -673,35 +538,23 @@ def test_sanity_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
bias,
activation,
normalization,
parallel_attention_mlp,
cpu_offload,
):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -721,7 +574,6 @@ def test_sanity_gpt(
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
activation=activation,
normalization=normalization,
......@@ -729,7 +581,7 @@ def test_sanity_gpt(
parallel_attention_mlp=parallel_attention_mlp,
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_gpt_126m():
......@@ -746,12 +598,10 @@ def test_sanity_gpt_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=True,
bias=True,
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
cpu_offload=False,
)
......@@ -759,18 +609,13 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -790,7 +635,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
params_dtype=dtype,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="causal",
normalization=normalization,
device="cuda",
......@@ -811,7 +655,6 @@ def test_sanity_bert_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
......@@ -820,18 +663,13 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -852,7 +690,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
device="cuda",
)
......@@ -872,7 +709,6 @@ def test_sanity_T5_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
......@@ -885,12 +721,6 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -917,17 +747,10 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
def test_sanity_drop_path(dtype, fp8_recipe, model):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -951,7 +774,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
_test_sanity_e2e(block, dtype, config, fp8_recipe, False)
@pytest.mark.parametrize("dtype", param_types)
......@@ -962,12 +785,6 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -991,26 +808,17 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -1030,7 +838,6 @@ def test_sanity_gradient_accumulation_fusion(
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
fuse_wgrad_accumulation=True,
device="cuda",
......@@ -1039,52 +846,6 @@ def test_sanity_gradient_accumulation_fusion(
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32)
......