"git@developer.sourcefind.cn:modelzoo/qwen3-omni_vllm.git" did not exist on "21c5b5da7fe1db87ef2df762dbafd5c33ee1d5f2"
Unverified commit bcbd4be0 authored by tcherckez-nvidia, committed by GitHub

Fix FlashAttention tests (#99)
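This change replaces the NVTE_FLASH_ATTN=0 workaround in the ONNX test script with a dedicated export mode: a new `onnx_export` context manager (module `transformer_engine/pytorch/export.py`, per the `.export` imports below) tracks a global flag, and `DotProductAttention` now consults `is_in_onnx_export_mode()` to fall back from FlashAttention to the unfused path while exporting.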


Signed-off-by: Tal Cherckez <tcherckez@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f56e4fd0
@@ -26,3 +26,5 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.checkpoint
+.. autoapifunction:: transformer_engine.pytorch.onnx_export
@@ -9,4 +9,4 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
+pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -32,7 +32,7 @@ from transformer_engine.pytorch.module import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
# Global test configuration knobs.
@@ -89,15 +89,16 @@ def do_export(
    os.makedirs(TEST_ARTIFACTS_DIR, exist_ok=True)
    fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
    inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
-    torch.onnx.export(model,
-                      inps,
-                      fname,
-                      verbose=False,
-                      opset_version=opset,
-                      input_names=input_names,
-                      output_names=output_names,
-                      do_constant_folding=False,
-                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
+    with te.onnx_export(True):
+        torch.onnx.export(model,
+                          inps,
+                          fname,
+                          verbose=False,
+                          opset_version=opset,
+                          input_names=input_names,
+                          output_names=output_names,
+                          do_constant_folding=False,
+                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
def to_numpy(tensor):
@@ -1003,3 +1004,10 @@ def test_export_transformer_layer(
        validate_result(fname, inp, model, atol=1e-3)
    elif precision != torch.float16:
        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)

+@pytest.mark.parametrize("enabled", [True, False])
+def test_export_ctx_manager(enabled):
+    assert is_in_onnx_export_mode() == False
+    with te.onnx_export(enabled):
+        assert is_in_onnx_export_mode() == enabled
+    assert is_in_onnx_export_mode() == False
@@ -10,6 +10,7 @@ from .module import LayerNorm
from .transformer import DotProductAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
+from .export import onnx_export
from .distributed import checkpoint
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
new file: transformer_engine/pytorch/export.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Export utilities for TransformerEngine"""
from contextlib import contextmanager

_IN_ONNX_EXPORT_MODE = False


@contextmanager
def onnx_export(
    enabled: bool = False,
) -> None:
    """
    Context manager for exporting to ONNX.

    .. code-block:: python

        with onnx_export(enabled=True):
            torch.onnx.export(model)

    Parameters
    ----------
    enabled: bool, default = `False`
             whether or not to enable export
    """
    global _IN_ONNX_EXPORT_MODE
    onnx_export_state = _IN_ONNX_EXPORT_MODE
    try:
        _IN_ONNX_EXPORT_MODE = enabled
        yield
    finally:
        _IN_ONNX_EXPORT_MODE = onnx_export_state


def is_in_onnx_export_mode() -> bool:
    """Returns True if ONNX export mode is enabled, False otherwise."""
    return _IN_ONNX_EXPORT_MODE
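For reference, a minimal usage sketch of the new context manager. The `torch.nn.Linear` model, input shapes, output file name, and opset version here are illustrative placeholders, not part of the commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import is_in_onnx_export_mode

# Placeholder model and input; any torch.nn.Module illustrates the pattern.
model = torch.nn.Linear(16, 16).eval()
inp = torch.randn(4, 16)

assert not is_in_onnx_export_mode()
with te.onnx_export(True):
    # While the flag is set, TE modules select ONNX-compatible code paths.
    assert is_in_onnx_export_mode()
    torch.onnx.export(model, (inp,), "model.onnx", opset_version=15)
assert not is_in_onnx_export_mode()  # previous state restored on exit

The try/finally save-and-restore in the context manager guarantees the flag is reset even if `torch.onnx.export` raises, mirroring the pattern used by `fp8_autocast`.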
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""FP8 utilies for TransformerEngine"""
"""FP8 utilities for TransformerEngine"""
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
@@ -41,6 +41,7 @@ from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
    checkpoint,
)
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
_flash_attn_version = version("flash-attn")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
@@ -442,6 +443,9 @@ class DotProductAttention(torch.nn.Module):
        ):
            use_flash_attention = False

+        if is_in_onnx_export_mode():
+            use_flash_attention = False
+
        if use_flash_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(self.flash_attention,
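This gating is why the NVTE_FLASH_ATTN=0 override could be dropped from the QA script: FlashAttention's fused CUDA kernel has no ONNX symbolic, so export mode must route attention through a traceable path. A simplified sketch of the idea follows; the `fused_attention`, `unfused_attention`, and `attention_forward` helpers are hypothetical stand-ins, not TE's actual functions:

import torch
from transformer_engine.pytorch.export import is_in_onnx_export_mode

def unfused_attention(q, k, v):
    # Plain-PyTorch path: every op here maps to a standard ONNX operator.
    scores = torch.matmul(q, k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def fused_attention(q, k, v):
    # Stand-in for a fused CUDA kernel such as FlashAttention, which
    # torch.onnx.export cannot trace; here it just reuses the math above.
    return unfused_attention(q, k, v)

def attention_forward(q, k, v):
    # Mirrors the gating in DotProductAttention: the fused kernel is
    # skipped whenever ONNX export mode is active.
    use_fused = q.is_cuda and not is_in_onnx_export_mode()
    if use_fused:
        return fused_attention(q, k, v)
    return unfused_attention(q, k, v)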