# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ This file contains tests for exporting TransformerEngine models to ONNX. The purpose of these tests is validation that TE models are converted to their correct ONNX representation. Toward this end, each test captures the output of a TE module forward pass, converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and validate the output against TE's output. Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented using custom ORT operations. To run many repetitive tests use pytest-loop: $ python3 -m pip install pytest-loop $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm For reproducability use: torch.manual_seed(0) """ import os import tempfile import pytest import warnings import numpy as np import onnxruntime as ort import torch from torch import nn as nn from typing import Optional, Union, Tuple, List import transformer_engine.pytorch as te from transformer_engine.common import recipe import transformer_engine_extensions as tex from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 from transformer_engine.pytorch.module.base import get_workspace import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.softmax as softmax_defs from transformer_engine.pytorch.utils import get_default_init_method from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.fp8 import FP8GlobalStateManager # Global test configuration knobs. # Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance). 
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))
if SAVE_TEST_IO:
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults

# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
    tempfile.gettempdir(), "./gen_onnx_models")

# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))

# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
# Opset used in the ONNX files generated by the tests.
OPSET = 17
assert OPSET >= TRILU_OPSET

# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


@pytest.fixture()
def seed_default_rng():
    """Reseed the PRNG with a fresh (non-deterministic) seed before each test.

    NOTE: ``torch.random.seed()`` picks a *random* seed; for reproducible runs
    use ``torch.manual_seed(0)`` instead (see the file header).
    """
    torch.random.seed()


@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
    """Set the maximum sequence length that can be used for attention masking"""
    os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"


def create_fp8_recipe():
    """Return the delayed-scaling FP8 recipe shared by all tests."""
    return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)


def do_export(
    model: torch.nn.Module,
    inp: torch.Tensor,
    fname: str,
    use_fp8: bool=True,
    opset: int=OPSET,
    input_names: List[str]=None,
    output_names: List[str]=None,
    dynamic_axes: List[str]=None,  # actually a dict mapping input name -> {axis: label}
):
    """Export `model` to ONNX under `NVTE_TEST_ARTIFACTS_DIR`/`fname`.

    Tracing is done in inference mode inside `te.fp8_autocast` and
    `te.onnx_export(True)` so TE emits its ONNX-compatible code paths.
    `None` entries in `inp` are dropped (together with their input names).
    """
    fp8_recipe = create_fp8_recipe()
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    with torch.inference_mode(), \
         te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), \
         warnings.catch_warnings():
        # Tracer warnings are expected and noisy during ONNX export.
        warnings.filterwarnings(
            action='ignore',
            category=torch.jit.TracerWarning,
            module=r'.*'
        )

        model.cuda().eval()
        os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
        fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

        inps = inp if isinstance(inp, (list, tuple)) else (inp,)
        assert len(inps) == len(input_names)
        # Inputs that are None are not real graph inputs: drop their names.
        inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
        input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]

        with te.onnx_export(True):
            torch.onnx.export(
                model,
                inps,
                fname,
                verbose=True,
                dynamic_axes=dynamic_axes,
                opset_version=opset,
                input_names=input_names,
                output_names=output_names,
                do_constant_folding=True,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)


def to_numpy(tensor):
    """Convert a tensor to numpy; BF16 is upcast to FP32 first (numpy has no bfloat16)."""
    if isinstance(tensor, torch.Tensor):
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.type(torch.float32)
        tensor = tensor.detach().cpu().numpy()
    return tensor


def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
    """Initialize the FP8 quantization scales in module"""
    NB_SCALES_PER_GEMM = 3  # One scale per: input, weights, and output GEMM tensors.
    nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
    module.fp8_init(num_gemms)
    module.fp8_meta["scaling_fwd"].scale = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") / scale
    module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") * scale


def te_infer(model: torch.nn.Module,
             inps: Union[Tuple[torch.tensor], torch.tensor],
             is_fp8: bool):
    """Transformer Engine forward propagation; always returns a tuple of outputs."""
    fp8_recipe = create_fp8_recipe()
    with torch.inference_mode(), \
         te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), \
         warnings.catch_warnings():
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        return te_outputs


def compare_outputs(onnx_outputs, te_outputs, atol, rtol,
                    max_errors_printed, allow_cnt_errors, fname):
    """Compare ORT and TE outputs; raise ValueError when the number of
    mismatched values exceeds `allow_cnt_errors`."""
    assert len(onnx_outputs) == len(te_outputs)
    # Compare ORT and PyTorch outputs.
    for onnx_output, te_output in zip(onnx_outputs, te_outputs):
        # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
        te_output = to_numpy(te_output)
        onnx_output = to_numpy(onnx_output)
        ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
        mismatches = ac.nonzero()
        mismatched_ids = [loc for loc in zip(*mismatches)]
        if mismatched_ids:
            # Log some information in case of error.
            print("*" * 100)
            nb_errors = len(mismatched_ids)
            nb_vals = min(nb_errors, max_errors_printed)
            print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
            print(f"Showing first {nb_vals} errors (ONNX -- TE):")
            abs_err = np.abs(onnx_output - te_output)
            errors = abs_err[mismatches]
            for loc in mismatched_ids[:nb_vals]:
                ref = te_output[loc]
                print(f"{onnx_output[loc]} -- {te_output[loc]} "
                      f"err={abs_err[loc]} > {atol + rtol * abs(ref)}")
            print(f"Max error: {np.max(errors)}")
            if nb_errors > allow_cnt_errors:
                raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")


def serialize_inputs_outputs(
    fname: str,
    inputs: Union[Tuple[torch.Tensor], torch.Tensor],
    te_outputs: List[torch.Tensor],
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
):
    """Save test I/O as Polygraphy JSON files next to the ONNX model.

    No-op unless NVTE_ONNX_EXPORT_SAVE_TEST_IO=1 (see SAVE_TEST_IO).
    """
    if not SAVE_TEST_IO:
        return

    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
    inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
    named_inputs = zip(input_names, inputs)
    input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
    json_fname = fname[:-len(".onnx")] + "_inputs.json"
    save_json(input_data, json_fname, description="custom input data")

    json_fname = fname[:-len(".onnx")] + "_output.json"
    named_outputs = zip(output_names, te_outputs)
    output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
    custom_outputs = RunResults()
    custom_outputs.add([output_data], runner_name="custom_runner")
    custom_outputs.save(json_fname)


def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float=1.e-8,  # np.isclose default atol
    rtol: float=1.e-5,  # np.isclose default rtol
    max_errors_printed: int=10,
    is_fp8: bool=False,
    allow_cnt_errors: int=0,
    input_names: List[str]=None,
    output_names: List[str]=None,
    te_outputs: List[torch.Tensor]=None,
):
    """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
    representation using ONNX Runtime (ORT) and ensure they are close.

    The purpose of the output comparison is to validate that TE models are converted to
    their correct ONNX representation by testing that TE and ORT outputs match within some
    small threshold (allowing for finite precision errors).

    Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring,
    a very small number (0-3) of outliers. This is fine to do because these outliers are due to
    small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
    representation (the tests assume both ORT or TE kernels are correct).

    Argument `te_outputs` can be used to provide pre-computed TE outputs.
    """

    def create_ort_session(fname: str, is_fp8: bool):
        """Create an ONNX Runtime session for validation."""

        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            if not os.path.exists(ORT_CUSTOM_OPS_LIB):
                raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}")
            session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB)
            print("registered custom FP8 Q/DQ ops!")

        kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']}
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            kwargs["sess_options"] = sess_options

        s = ort.InferenceSession(fname, **kwargs)
        return s

    def create_ort_input_dict(session, inputs):
        """Map ORT graph-input names to numpy inputs (None inputs are skipped)."""
        inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
        input_names = [x.name for x in session.get_inputs()]
        inps = [to_numpy(x) for x in inputs if x is not None]
        inp_dict = dict(zip(input_names, inps))
        return inp_dict

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    # Run ORT session and TE model.
    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    if not te_outputs:
        te_outputs = te_infer(model, inps, is_fp8)
    ort_s = create_ort_session(fname, is_fp8)
    input_feed = create_ort_input_dict(ort_s, inps)
    onnx_outputs = ort_s.run(None, input_feed=input_feed)
    compare_outputs(onnx_outputs, te_outputs, atol, rtol,
                    max_errors_printed, allow_cnt_errors, fname)


def create_meta(scale_factor: float, size: int=1):
    """Build an FP8TensorMeta with `size` scales initialized from `scale_factor`."""
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta


def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    """Map a torch dtype to the suffix used in generated ONNX file names."""
    if fake_bf16_io:
        assert dtype == torch.bfloat16
        return "_fake_bf16"
    return {
        torch.float32: "_fp32",
        torch.float16: "_fp16",
        torch.bfloat16: "_bf16",
    }[dtype]


def as_te_type(dtype: torch.dtype):
    """Map a torch dtype to the corresponding TE DType enum."""
    return {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }[dtype]


def get_attn_mask_str(use_mask, attn_mask_type):
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        return "_mask" if use_mask else "_no-mask"
    attn_mask_str = "_padding-no-mask"
    attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
    attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
    return attn_mask_str


"""
Tests cases begin here.
"""

@skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
    "precision, atol", [
    [torch.float32, 1e-7],
    [torch.float16, 1e-7],
    [torch.bfloat16, 5e-3],
    ["fake-torch.bfloat16", 5e-3],
])
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
    """Export and validate an FP8 quantize/dequantize round-trip."""
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class TestFP8_QDQ(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = cast_to_fp8(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda",
                      dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_QDQ(fake_bf16_io)
    do_export(model, inp, fname)
    te_outputs = te_infer(model, inp, is_fp8=True)
    serialize_inputs_outputs(fname, inp, te_outputs)
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)


@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
    "precision, atol", [
    [torch.float32, 1e-5],
    [torch.float16, 1e-5],
    [torch.bfloat16, 5e-3],
    ["fake-torch.bfloat16", 5e-3]
])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
    """Export and validate an FP8 fused-GeLU followed by dequantization."""
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class TestFP8_Gelu(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = gelu(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda",
                      dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_Gelu(fake_bf16_io)
    do_export(model, inp, fname)
    te_outputs = te_infer(model, inp, is_fp8=True)
    serialize_inputs_outputs(fname, inp, te_outputs)
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(fname, inp, model, rtol=0, atol=atol,
                        is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs)


@pytest.mark.parametrize("scale_factors", [(224, 224,), ])
@pytest.mark.parametrize(
    "precision, use_fp8, use_bias, use_gelu", [
    (torch.float32, False, False, False),
    (torch.float16, False, False, False),
    (torch.bfloat16, False, False, False),
    (torch.float32, False, True, False),
    (torch.float16, False, True, False),
    (torch.bfloat16, False, True, False),
    (torch.float32, False, True, True),
    (torch.float16, False, True, True),
    (torch.bfloat16, False, True, True),

    # For FP8 GEMM GeLU is not used.
    (torch.float32, True, False, False),
    (torch.float16, True, False, False),
    (torch.bfloat16, True, False, False),
    # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
    (torch.float16, True, True, False),
    (torch.bfloat16, True, True, False),
])
def test_export_gemm(
    seed_default_rng,
    precision,  # Precision of inputs, weights, output and bias
    use_fp8,
    use_bias,
    use_gelu,
    scale_factors
):
    """Export and validate a (possibly FP8) GEMM, optionally with bias/GeLU fusion."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    class TestFP8_GEMM(nn.Module):
        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision

            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)

            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features,
                                          dtype=precision, device="cuda")

            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision

        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)

            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)

            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.inp_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.weights_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                use_split_accumulator=False)
            return ret

    class Test_GEMM(nn.Module):
        def __init__(self, precision, use_bias=False, gelu=False):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision
            bias_size = out_features
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features,
                                          dtype=precision, device="cuda")

        def forward(self, inp, weight):
            outp_type = self.precision

            # note: due to logic in lines 104:116 and L129 in cpp_extensions.py
            # it appears either bias OR gelu can be activated, not both
            ret, _, _ = gemm(
                weight,
                inp,
                outp_type,
                get_workspace(),

                # test bias
                bias=self.bias,
                use_bias=self.use_bias,

                # test gelu
                gelu=self.gelu,
                gelu_input=self.gelu_input,
                grad=False,  # only True for backward pass
                accumulate=False,
            )
            return ret

    # If gelu is applied then bias must be added, as defined by TE kernel.
    if use_gelu:
        assert use_bias

    # Set dimensions (these are arbitrary).
    out_features = 128
    hidden_size = 256
    in_features = 64
    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    weight = torch.randn(out_features, in_features, device="cuda", dtype=precision)

    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    gelu_str = "_gelu" if use_gelu else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
    input_names = ['input', 'weight']
    if use_fp8:
        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
        do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
        te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
        serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
        if precision != torch.bfloat16:
            validate_result(
                fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
                is_fp8=True, input_names=input_names, te_outputs=te_outputs)
    else:
        model = Test_GEMM(precision, use_bias, use_gelu)
        do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
        te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
        serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
        if precision != torch.bfloat16:
            validate_result(
                fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
                input_names=input_names, te_outputs=te_outputs)


@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
    "use_fp8, precision, atol", [
    [False, torch.float32, 1e-7],
    [False, torch.float16, 1e-7],
    [False, torch.bfloat16, 1e-7],
    [False, "fake-torch.bfloat16", 1e-7],
    [True, torch.float32, 1e-7],
    [True, torch.float16, 1e-7],
    [True, torch.bfloat16, 1e-2],
    [True, "fake-torch.bfloat16", 1e-2]
])
def test_export_layernorm(
    seed_default_rng,
    use_fp8: bool,
    scale_factor: float,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    atol: float
):
    """Export and validate LayerNorm (plain and FP8-fused variants)."""
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    inp_shape = [64, 32]

    class Test_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            eps = 1e-6  # An arbitrary small value
            dtype = torch.float if fake_bf16_io else precision
            # Fix: honor the parametrized zero_centered_gamma (was hard-coded False,
            # leaving the zero_centered_gamma=True cases untested in the non-FP8 path).
            self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
                                   zero_centered_gamma=zero_centered_gamma).eval().cuda()

        def forward(self, inp):
            ret = self.ln(inp)
            return ret

    class TestFP8_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, device="cuda",
                                      dtype=torch.float32 if fake_bf16_io else precision)
            self.bias = torch.zeros(*normalized_shape, device="cuda",
                                    dtype=torch.float32 if fake_bf16_io else precision)
            self.eps = 1e-6  # An arbitrary small value
            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3

        def forward(self, inp):
            ret = texcpp.layernorm_fwd_fp8_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                zero_centered_gamma)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                as_te_type(precision))
            if fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    inp = torch.randn(*inp_shape, device="cuda",
                      dtype=torch.float32 if fake_bf16_io else precision)
    model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
    fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, inp, fname, use_fp8=use_fp8)
    te_outputs = te_infer(model, inp, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inp, te_outputs)
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(
            fname, inp, model, atol=atol, is_fp8=use_fp8,
            allow_cnt_errors=3, te_outputs=te_outputs)


@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize(
    "use_fp8, precision, atol", [
    [False, torch.float32, 1e-7],
    [False, torch.float16, 1e-7],
    [False, torch.bfloat16, 1e-7],
    [False, "fake-torch.bfloat16", 1e-7],
    [True, torch.float32, 1e-7],
    [True, torch.float16, 1e-7],
    [True, torch.bfloat16, 1e-2],
    [True, "fake-torch.bfloat16", 1e-2]
])
def test_export_rmsnorm(
    seed_default_rng,
    use_fp8: bool,
    scale_factor: float,
    precision: torch.dtype,
    atol: float
):
    """Export and validate RMSNorm (plain and FP8-fused variants)."""
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    inp_shape = [64, 32]

    class Test_RMSnorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            eps = 1e-6  # An arbitrary small value
            dtype = torch.float if fake_bf16_io else precision
            self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype).eval().cuda()

        def forward(self, inp):
            ret = self.ln(inp)
            return ret

    class TestFP8_RMSnorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, device="cuda",
                                      dtype=torch.float32 if fake_bf16_io else precision)
            self.eps = 1e-6  # An arbitrary small value
            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3

        def forward(self, inp):
            ret = texcpp.rmsnorm_fwd_fp8_inf(
                inp,
                self.weight,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                False)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                as_te_type(precision))
            if fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    inp = torch.randn(*inp_shape, device="cuda",
                      dtype=torch.float32 if fake_bf16_io else precision)
    model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm()
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
    # Fix: artifact was previously named "te.layernorm…", colliding with the
    # files written by test_export_layernorm for identical parameters.
    fname = f"te.rmsnorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, inp, fname, use_fp8=use_fp8)
    te_outputs = te_infer(model, inp, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inp, te_outputs)
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(
            fname, inp, model, atol=atol, is_fp8=use_fp8,
            allow_cnt_errors=3, te_outputs=te_outputs)


@skip_FP8
@pytest.mark.parametrize("softmax_fn", [
    softmax_defs.ScaledUpperTriangMaskedSoftmax,
    softmax_defs.ScaledMaskedSoftmax,
    softmax_defs.ScaledSoftmax,
    te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
    """Export and validate TE's scaled/masked softmax variants."""
    class Test_Softmax(nn.Module):
        def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
            super().__init__()
            self.softmax_fn = softmax_fn
            self.scale = 8  # arbitrary value
            self.mask_inp = mask_inp
            self.fused_scaled_softmax = None
            self.fake_bf16_io = fake_bf16_io
            if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
                self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
                    mask_func=te.utils.attention_mask_func,
                    softmax_in_fp32=True,
                )

        def forward(self, inp, mask):
            if self.fake_bf16_io:
                inp = inp.type(torch.bfloat16)

            if self.fused_scaled_softmax:
                ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
            else:
                if self.mask_inp:
                    ret = self.softmax_fn.apply(inp, mask, self.scale)
                else:
                    ret = self.softmax_fn.apply(inp, self.scale)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    fake_bf16_io = precision == "fake-torch.bfloat16"
    precision = torch.bfloat16 if fake_bf16_io else precision

    # Set dimensions (these are arbitrary).
    batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
    mask = None
    input_names = ["input", "mask"]
    inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
    if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
        inp_shape = [batch_size, seq_len_q, seq_len_k]
        kernel_str = "ScaledUpperTriangMaskedSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)
    elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision)
        mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        kernel_str = "ScaledMaskedSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
    elif softmax_fn == softmax_defs.ScaledSoftmax:
        kernel_str = "ScaledSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)
    elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
        kernel_str = "TorchSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)

    input_tensor = torch.randn(
        *inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"{kernel_str}{high_prec_str}.onnx"
    inp = (input_tensor, mask)
    dynamic_axes = {}
    if mask is not None:
        dynamic_axes = {"mask": {2: "seq_len_q", 3: "seq_len_k"}}
    do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes)
    te_outputs = te_infer(model, inp, is_fp8=False)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    if fake_bf16_io or precision != torch.bfloat16:
        atol = 5e-2 if fake_bf16_io else 1e-3
        validate_result(fname, inp, model, atol=atol,
                        input_names=input_names, te_outputs=te_outputs)


# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8 @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) def test_softmax_mask_fn(seed_default_rng, precision): fake_bf16_io = precision == "fake-torch.bfloat16" # reset precision to torch.bfloat16 after capturing fake BF16 mode precision = torch.bfloat16 if fake_bf16_io else precision class Test_Softmax(nn.Module): def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool): super().__init__() self.scale = 1 # arbitrary value self.fake_bf16_io = fake_bf16_io if use_default_te_mask_fn: os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0" else: os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}" # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax # even when is_in_onnx_export_mode()==False. os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0" self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( mask_func=te.utils.attention_mask_func, softmax_in_fp32=True, ) def forward(self, inp, mask): if self.fake_bf16_io: inp = inp.type(torch.bfloat16) ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale) if self.fake_bf16_io: ret = ret.type(torch.float) return ret # Set dimensions (these are arbitrary). mask = None batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32 assert seq_len_q == seq_len_k # This is a causal (TRILU) mask inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k] input_tensor = torch.randn( *inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision) inp = (input_tensor, mask) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) # Compare the outputs of TE when using the default softmax mask # to the TE outputs produced when using the ONNX-compatible causal mask. # This verifies that _get_onnx_export_causal_mask generates a correct mask. 
model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io) te_outputs_default_mask = te_infer(model, inp, is_fp8=True) with te.onnx_export(True): # ONNX export mode forces use of the ONNX-compatible causal mask. model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io) te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True) compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask, atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking") # Compare the outputs of TE when using the default softmax mask # to the ORT ONNX outputs produced when using the ONNX-compatible causal mask. input_names = ["input", "mask"] kernel_str = "FusedScaleMaskSoftmax" fname = f"{kernel_str}{high_prec_str}.onnx" do_export(model, inp, fname, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names) if fake_bf16_io or precision != torch.bfloat16: atol = 1e-2 if fake_bf16_io else 1e-3 validate_result( fname, inp, model_onnx_mask, atol=atol, input_names=input_names, te_outputs=te_outputs_default_mask) @pytest.mark.parametrize("scale_factor", [1]) @pytest.mark.parametrize("use_fp8", [False, True]) # Returning the bias is a TE fusion optimization we don't care about. @pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize( "precision, use_bias",[ (torch.float32, False), (torch.float32, True), (torch.float16, False), (torch.float16, True), # Todo: cannot configure BF16 when bias is disabled (ORT issue?) (torch.bfloat16, False), # Todo: cannot configure BF16 when bias is enabled (ORT issue?) (torch.bfloat16, True), ]) def test_export_linear( seed_default_rng, scale_factor: float, use_fp8: bool, use_bias: bool, return_bias: bool, precision: torch.dtype ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) # Set dimensions (these are arbitrary). 
in_features = 64 out_features = 256 hidden_size = 256 class Test_Linear(nn.Module): def __init__(self, in_features, out_features, use_bias, return_bias, precision ): super().__init__() self.linear = te.Linear( in_features, out_features, bias=use_bias, return_bias=return_bias, params_dtype=precision ) def forward(self, inp): ret = self.linear(inp) return ret inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) fp8_str = "_fp8" if use_fp8 else "" bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" with te.fp8_autocast(enabled=use_fp8): model = Test_Linear( in_features, out_features, use_bias, return_bias, precision ).to(device='cuda') if use_fp8: set_layer_scale(model.linear, scale_factor, num_gemms=1) do_export(model, inp, fname, use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs) if precision in (torch.bfloat16, ): return if not use_fp8: validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) else: validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs) @pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("use_fp8", [False, True]) # Returning the bias is a TE fusion optimization we don't care about. 
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision, use_bias", [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
        (torch.bfloat16, True),
        (torch.bfloat16, False),
    ])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
    seed_default_rng,
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    normalization: str,
):
    """Export te.LayerNormLinear to ONNX and validate the graph's output with ORT."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256

    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)

    # Compose the artifact name from the test configuration.
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"

    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormLinear(
            hidden_size,
            3 * hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=1)
        do_export(model, inp, fname, use_fp8)
        te_outputs = te_infer(model, inp, is_fp8=use_fp8)
        serialize_inputs_outputs(fname, inp, te_outputs)
        if precision in (torch.bfloat16, ):
            # ORT cannot be validated against BF16 outputs; export-only coverage.
            return
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
        else:
            # NOTE(review): previously guarded by `elif precision != torch.bfloat16`,
            # which is always true here because the bf16 case returned above —
            # simplified to a plain `else` (behavior unchanged).
            validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("use_fp8", [False, True]) # Returning the bias is a TE fusion optimization we don't care about. @pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize( "precision, use_bias",[ (torch.float32, False), (torch.float32, True), (torch.float16, True), (torch.float16, False), (torch.bfloat16, True), (torch.bfloat16, False), ]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("activation", supported_activations) @pytest.mark.parametrize("normalization", all_normalizations) def test_export_layernorm_mlp( seed_default_rng, scale_factor: float, use_fp8: bool, use_bias: bool, return_bias: bool, return_layernorm_output: bool, precision: torch.dtype, zero_centered_gamma: bool, activation: str, normalization: str, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if normalization == "RMSNorm" and zero_centered_gamma: pytest.skip("RMSNorm does not support zero_centered_gamma yet!") # Set dimensions (these are arbitrary). 
in_features = 64 out_features = 256 hidden_size = 256 ffn_hidden_size = 256 inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) fp8_str = "_fp8" if use_fp8 else "" bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" with te.fp8_autocast(enabled=use_fp8): model = te.LayerNormMLP( hidden_size, ffn_hidden_size, bias=use_bias, return_bias=return_bias, return_layernorm_output=return_layernorm_output, params_dtype=precision, zero_centered_gamma=zero_centered_gamma, activation=activation, normalization=normalization, ).to(device='cuda') if use_fp8: set_layer_scale(model, scale_factor, num_gemms=2) do_export(model, inp, fname, use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs) if precision in (torch.bfloat16, ): return atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) @skip_FP8 @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ (torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask) (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) (torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask) (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) (torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask) (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) ]) def test_export_core_attention( seed_default_rng, set_max_seq_len, precision: torch.dtype, use_mask: bool, attn_mask_type: str, ): # Set dimensions (these are arbitrary). 
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"] attention_mask = None if use_mask: # Generate a random mask with 50% probability for 0 or 1. probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type) mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, attention_dropout=0.5, ).to(device='cuda') do_export(model, inp, fname, input_names=input_names, use_fp8=True) te_outputs = te_infer(model, inp, is_fp8=True) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16, ): return validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs) test_configs_multihead_attention = [ #"use_mask, attn_mask_type" (False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax (True, "padding"), # calls ScaledMaskedSoftmax ] test_configs_attention_type = [ #"input_layernorm, attention_type, fuse_qkv_params" (True, "self", True), (False, "self", True), (True, "self", False), (False, "self", False), (True, "cross", True), (False, "cross", True), (True, "cross", False), (False, "cross", False), ] @pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) 
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type) def test_export_multihead_attention( seed_default_rng, set_max_seq_len, use_fp8: bool, use_mask: bool, attn_mask_type: str, precision: torch.dtype, return_layernorm_output: bool, input_layernorm: bool, attention_type: str, fuse_qkv_params: bool ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) hidden_size = 256 sequence_length = 128 batch_size = 4 num_attention_heads = 32 kv_channels = 8 attention_dropout = 0.1 layernorm_epsilon = 1e-5 init_method = output_layer_init_method = get_default_init_method() attention_args = ( hidden_size, num_attention_heads, kv_channels, attention_dropout, layernorm_epsilon, init_method, output_layer_init_method, ) hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. 
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) encoder_output = None if attention_type == "cross": encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") fp8_str = "_fp8" if use_fp8 else "" dtype_str = dtype2str(precision) attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) input_ln_str = "_input-ln" if input_layernorm else "" fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" model = te.MultiheadAttention( *attention_args, params_dtype=precision, return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm, attention_type=attention_type, fuse_qkv_params=fuse_qkv_params, return_bias=True, ).to(device='cuda') inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type) input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"] output_names=["attention_output", "attention_bias"] do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, "attention_output": {0: "seq", 1:"bs"}}) te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names) if precision in (torch.bfloat16, ): return if not use_fp8: validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names, output_names=output_names, te_outputs=te_outputs) else: validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, allow_cnt_errors=3, te_outputs=te_outputs) # In GPT generative phase 
(inference) the input sequence is smaller than the maximum # allowed sequence length and we want to test this condition. # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). is_generative_phase = (attn_mask_type == "causal" and attention_type == "self") if is_generative_phase: seq_len_offset = 8 hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda") inp_generative = (hidden_states_generative, attention_mask, encoder_output) if not use_fp8: validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names) else: validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, allow_cnt_errors=3) @pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) @pytest.mark.parametrize("output_layernorm", [ #True, # TO DO: handle this False ]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("fuse_qkv_params", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("activation", supported_activations) def test_export_transformer_layer( seed_default_rng, set_max_seq_len, use_fp8: bool, use_mask: bool, attn_mask_type: str, output_layernorm: bool, precision: torch.dtype, fuse_qkv_params: bool, zero_centered_gamma: bool, activation: str, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) # Layer configuration hidden_size = 64 sequence_length = 128 batch_size = 1 ffn_hidden_size = 256 num_attention_heads = 4 input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") input_names = ["input", "attention_mask", "self_attn_mask_type"] attention_mask = None if use_mask and attn_mask_type != 
"causal": # Generate a random mask with 50% probability for 0 or 1. probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) inp = (input_tensor, attention_mask, attn_mask_type) fp8_str = "_fp8" if use_fp8 else "" fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" high_prec_str = dtype2str(precision) attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx" model = te.TransformerLayer( hidden_size, ffn_hidden_size, num_attention_heads, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma, activation=activation).to(device='cuda') do_export(model, inp, fname, use_fp8, input_names=input_names) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16, ): return atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs) @pytest.mark.parametrize("use_fp8", [True]) @pytest.mark.parametrize("ln_scale_factor", [448*2]) @pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) def test_export_gemm_layernorm( seed_default_rng, use_fp8: bool, ln_scale_factor: float, gemm_scale_factors: Tuple[float, float], precision: torch.dtype, zero_centered_gamma: bool ): """This is a regression test for testing that all LN inputs have the same type. The test sets up GEMM with FP32 output which feeds into an LN that is configured with FP16 or BF16 weights and bias. 
""" # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) class TestFP8_GemmLayernorm(nn.Module): def __init__(self) -> None: super().__init__() normalized_shape = torch.Size(inp.shape[1:]) self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") self.eps = 1e-6 # An arbitrary small value self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.meta = create_meta(ln_scale_factor) self.fp8_type = tex.DType.kFloat8E4M3 self.gemm = TestFP8_GEMM( precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors) def forward(self, inp, weight): x = self.gemm(inp, weight) x = texcpp.layernorm_fwd_fp8_inf( x, self.weight, self.bias, self.eps, self.meta, self.fp8_tensor, self.fp8_type, zero_centered_gamma) x = cast_from_fp8( x, self.meta, self.fp8_tensor, self.fp8_type, tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16) return x out_features = 128 hidden_size = 128 in_features = 128 class TestFP8_GEMM(nn.Module): def __init__(self, precision, use_bias, gelu, scale_factors): super().__init__() self.use_bias = use_bias self.gelu = gelu self.precision = precision self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT nb_inp_scales, nb_weight_scales = 1, out_features act_scale_factor, weight_scale_factor = scale_factors self.meta_inp = create_meta(act_scale_factor, nb_inp_scales) self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales) bias_size = nb_weight_scales self.bias = torch.randn(bias_size, dtype=precision, device="cuda") self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision def forward(self, inp, weight): inp_fp8 = cast_to_fp8( inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) weight_fp8 = 
cast_to_fp8( weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type) ret = fp8_gemm( weight_fp8, self.meta_weight.scale_inv, self.fp8_tensor_weight, self.inp_type, inp_fp8, self.meta_inp.scale_inv, self.fp8_tensor_inp, self.weights_type, self.outp_type, get_workspace(), bias=self.bias, use_bias=self.use_bias, use_split_accumulator=False) return ret inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") weight = torch.randn(out_features, in_features, dtype=precision, device="cuda") model = TestFP8_GemmLayernorm() high_prec_str = dtype2str(precision) fp8_str = f"_fp8" if use_fp8 else "" fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" input_names = ['input', 'weight'] do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) if precision not in (torch.bfloat16, ): validate_result( fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, input_names=input_names, te_outputs=te_outputs) @skip_FP8 @pytest.mark.parametrize("use_fp8", [True, False]) @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("zero_centered_gamma", [True]) def test_export_gpt_generation( seed_default_rng, set_max_seq_len, use_fp8: bool, precision: torch.dtype, zero_centered_gamma: bool, ): """Test that the ONNX model can correctly handle inputs with different shapes and that the attention mask it adjusted on-the-fly to different sequence lengths. 
""" # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) # Layer configuration hidden_size = 64 sequence_length = 128 batch_size = 1 ffn_hidden_size = 256 num_attention_heads = 4 attention_mask = None use_mask = True attn_mask_type = "causal" fuse_qkv_params = True output_layernorm = False fp8_str = "_fp8" if use_fp8 else "" fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" high_prec_str = dtype2str(precision) attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx" model = te.TransformerLayer( hidden_size, ffn_hidden_size, num_attention_heads, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma).to(device='cuda') # "Context phase": use full input sequence length input_names = ["input", "attention_mask", "self_attn_mask_type"] output_names = ["output"] input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") inp = (input_tensor, None, attn_mask_type) do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"input": {0: "seq", 1:"bs"}, "output": {0: "seq", 1:"bs"}, }) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names) if precision not in (torch.bfloat16, ): validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs) # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. 
sequence_length = 1 if not use_fp8 else 8 input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") inp = (input_tensor, attention_mask) te_outputs = te_infer(model, inp, is_fp8=use_fp8) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision not in (torch.bfloat16, ): validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs) @pytest.mark.parametrize("enabled", [True, False]) def test_export_ctx_manager(enabled): assert is_in_onnx_export_mode() == False with te.onnx_export(enabled): assert is_in_onnx_export_mode() == enabled assert is_in_onnx_export_mode() == False