# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
This file contains tests for exporting TransformerEngine models to ONNX.

The purpose of these tests is validation that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.

Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.

To run many repetitive tests use pytest-loop:
    $ python3 -m pip install pytest-loop
    $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm

For reproducibility use: torch.manual_seed(0)
"""

import os
import tempfile
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Union, Tuple
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import is_fp8_available

# Global test configuration knobs.

# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = False

if SAVE_TEST_IO:
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults

# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")

# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))

# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
# Opset used in the ONNX files generated by the tests.
OPSET = 17
assert OPSET >= TRILU_OPSET

# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")

fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)


def create_fp8_recipe():
    """Return the delayed-scaling FP8 recipe (E4M3, margin 0, interval 1) shared by all tests."""
    return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)


def do_export(
    model: torch.nn.Module,
    inp: torch.Tensor,
    fname: str,
    use_fp8: bool = True,
    opset: int = OPSET,
    input_names: Union[list, None] = None,
    output_names: Union[list, None] = None,
):
    """Export `model` to an ONNX file named `fname` inside NVTE_TEST_ARTIFACTS_DIR.

    Args:
        model: Module to export; it is moved to CUDA and put in eval mode.
        inp: A single input or a list/tuple of inputs used for tracing.
        fname: File name (not path) of the generated ONNX model.
        use_fp8: Whether the export is traced under FP8 autocast.
        opset: ONNX opset version to export with.
        input_names: Names of the graph inputs; defaults to ["input"].
        output_names: Names of the graph outputs; defaults to ["output"].
    """
    # Resolve the defaults here instead of using mutable default arguments.
    input_names = ["input"] if input_names is None else input_names
    output_names = ["output"] if output_names is None else output_names

    fp8_recipe = create_fp8_recipe()

    with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        warnings.filterwarnings(
            action='ignore',
            category=torch.jit.TracerWarning,
            module=r'.*'
        )

        model.cuda().eval()
        os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
        fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
        inps = inp if isinstance(inp, (list, tuple)) else (inp,)
        with te.onnx_export(True):
            torch.onnx.export(
                model,
                inps,
                fname,
                verbose=True,
                opset_version=opset,
                input_names=input_names,
                output_names=output_names,
                # Do not constant-fold because torch.onnx incorrectly folds
                # layer_norm(data, scale=add(gamma,1)) to layer_norm(data, scale=gamma)
                # when we use LN with zero-centered gamma.
                do_constant_folding=False,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)


def to_numpy(tensor):
    """Copy `tensor` to the CPU and return it as a numpy array."""
    return tensor.cpu().numpy()


def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
    """Initialize the FP8 quantization scales in module"""
    NB_SCALES_PER_GEMM = 3  # One scale per: input, weights, and output GEMM tensors.
    nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
    module.fp8_init(num_gemms)
    module.fp8_meta["scaling_fwd"].scale = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") / scale
    module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") * scale


def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.Tensor], torch.Tensor], is_fp8: bool):
    """Transformer Engine forward propagation.

    Return results after copying to the CPU and converting to numpy.
    The result is always a list, even when the model returns a single tensor.
    """
    fp8_recipe = create_fp8_recipe()
    with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        # Accept the same input containers as do_export (a single tensor, a list or a tuple).
        args = inps if isinstance(inps, (list, tuple)) else (inps,)
        te_outputs = model(*args)
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        te_outputs_np = [to_numpy(te_output) for te_output in te_outputs]
        return te_outputs_np


def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float = 1.e-8,  # np.isclose default atol
    rtol: float = 1.e-5,  # np.isclose default rtol
    max_errors_printed: int = 10,
    is_fp8: bool = False,
    allow_cnt_errors: int = 0,
    input_names: Union[list, None] = None,
    output_names: Union[list, None] = None,
):
    """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
    representation using ONNX Runtime (ORT) and ensure they are close.

    The purpose of the output comparison is to validate that TE models are converted to
    their correct ONNX representation by testing that TE and ORT outputs match within some
    small threshold (allowing for finite precision errors).

    Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring
    a very small number (0-3) of outliers. This is fine to do because these outliers are due to
    small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
    representation (the tests assume both ORT and TE kernels are correct).

    `input_names` defaults to ["input"] and `output_names` to ["output"]; they are only used
    when test inputs/outputs are serialized (SAVE_TEST_IO).

    Raises:
        ValueError: if more than `allow_cnt_errors` values diverge in any output.
        FileNotFoundError: if `is_fp8` is set and the custom ORT ops library is missing.
    """
    # Resolve the defaults here instead of using mutable default arguments.
    input_names = ["input"] if input_names is None else input_names
    output_names = ["output"] if output_names is None else output_names

    def create_ort_session(fname: str, is_fp8: bool):
        """Create an ONNX Runtime session for validation."""
        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            if not os.path.exists(ORT_CUSTOM_OPS_LIB):
                raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}")
            session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB)
            print("registered custom FP8 Q/DQ ops!")

        kwargs = {}
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            kwargs["sess_options"] = sess_options

        s = ort.InferenceSession(fname, **kwargs)
        return s

    def create_ort_input_dict(session, inps):
        """Map the ORT session's input names to the numpy-converted test inputs."""
        inp_dict = {}
        if isinstance(inps, (tuple, list)):
            nonetype_inputs = 0
            for idx, inp in enumerate(inps):
                if inp is None:
                    # None inputs are skipped, so shift the index into the session
                    # inputs by the number of Nones seen so far.
                    nonetype_inputs += 1
                    continue
                inp_dict[session.get_inputs()[idx - nonetype_inputs].name] = to_numpy(inp)
        else:
            inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
        return inp_dict

    def serialize_inputs_outputs(fname, inputs, inputs_names, te_outputs, output_names):
        """Save the test inputs and TE outputs next to the ONNX file (only if SAVE_TEST_IO)."""
        if not SAVE_TEST_IO:
            return
        inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
        named_inputs = zip(inputs_names, inputs)
        input_data = [{k: to_numpy(v) for k, v in named_inputs if v is not None}]
        json_fname = fname[:-len(".onnx")] + "_inputs.json"
        save_json(input_data, json_fname, description="custom input data")

        # Outputs are not serialized for files with "bf16" in their name.
        if "bf16" in fname:
            return
        json_fname = fname[:-len(".onnx")] + "_output.json"
        named_outputs = zip(output_names, te_outputs)
        output_data = dict()
        for out_name, outp in named_outputs:
            if outp is not None:
                assert out_name not in output_data
                output_data[out_name] = outp
        custom_outputs = RunResults()
        custom_outputs.add([output_data], runner_name="custom_runner")
        custom_outputs.save(json_fname)

    def compare_outputs(onnx_outputs, te_outputs):
        """ Compare ORT and TE outputs."""
        assert len(onnx_outputs) == len(te_outputs)
        # Compare ORT and PyTorch outputs.
        for onnx_output, te_output in zip(onnx_outputs, te_outputs):
            # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
            ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)

            mismatches = ac.nonzero()
            mismatched_ids = [loc for loc in zip(*mismatches)]
            if mismatched_ids:
                # Log some information in case of error.
                print("*" * 100)
                nb_errors = len(mismatched_ids)
                nb_vals = min(nb_errors, max_errors_printed)
                print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
                print(f"Showing first {nb_vals} errors (ONNX -- TE):")
                abs_err = np.abs(onnx_output - te_output)
                errors = abs_err[mismatches]
                for loc in mismatched_ids[:nb_vals]:
                    ref = te_output[loc]
                    print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
                print(f"Max error: {np.max(errors)}")
                if nb_errors > allow_cnt_errors:
                    raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")

    # Run ORT session and TE model.
    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    te_outputs = te_infer(model, inps, is_fp8)
    ort_s = create_ort_session(fname, is_fp8)
    input_feed = create_ort_input_dict(ort_s, inps)
    onnx_outputs = ort_s.run(None, input_feed=input_feed)
    compare_outputs(onnx_outputs, te_outputs)
    serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)


def create_meta(scale_factor: float, size: int = 1):
    """Create an FP8TensorMeta with `size` scales set to `scale_factor` (and their inverses)."""
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta


def dtype2str(dtype: torch.dtype):
    """Return the file-name suffix for `dtype`; raises KeyError for unsupported dtypes."""
    return {
        torch.float32: "_fp32",
        torch.float16: "_fp16",
        torch.bfloat16: "_bf16",
    }[dtype]


def as_te_type(dtype: torch.dtype):
    """Map a torch dtype to the matching tex.DType; raises KeyError for unsupported dtypes."""
    return {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }[dtype]


def get_attn_mask_str(use_mask, attn_mask_type):
    """Return the file-name suffix describing the attention-mask configuration."""
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        return "_mask" if use_mask else "_no-mask"
    attn_mask_str = "_padding-no-mask"
    attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
    attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
    return attn_mask_str


"""
Tests cases begin here.
"""


@skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
    "precision, atol", [
        [torch.float32, 1e-7],
        [torch.float16, 1e-7],
        [torch.bfloat16, 5e-3]
    ])
def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
    """Export and validate an FP8 quantize -> dequantize round trip."""
    class TestFP8_QDQ(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = cast_to_fp8(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)

            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    # For BF16 the module takes FP32 input and returns FP32 output ("fake" BF16 I/O).
    fake_bf16_io = precision == torch.bfloat16
    inp = torch.randn(hidden_size, in_features, device="cuda",
                      dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision)
    fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_QDQ(fake_bf16_io)
    do_export(model, inp, fname)
    validate_result(fname, inp, model, atol=atol, is_fp8=True)


@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
    "precision, atol", [
        [torch.float32, 1e-5],
        [torch.float16, 1e-5],
        [torch.bfloat16, 5e-3]
    ])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
    """Export and validate the FP8 GeLU op followed by a dequantize."""
    class TestFP8_Gelu(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = fp8_gelu(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    # For BF16 the module takes FP32 input and returns FP32 output ("fake" BF16 I/O).
    fake_bf16_io = precision == torch.bfloat16
    inp = torch.randn(hidden_size, in_features, device="cuda",
                      dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision)
    fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_Gelu(fake_bf16_io)
    do_export(model, inp, fname)
    validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2)


@pytest.mark.parametrize("scale_factors", [
    (224, 224,),
])
@pytest.mark.parametrize(
    "precision, use_fp8, use_bias, use_gelu", [
        (torch.float32, False, False, False),
        (torch.float16, False, False, False),
        (torch.float32, False, True, False),
        (torch.float16, False, True, False),
        (torch.float32, False, True, True),
        (torch.float16, False, True, True),

        # For FP8 GEMM GeLU is not used.
        (torch.float32, True, False, False),
        (torch.float16, True, False, False),
        # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
        (torch.float16, True, True, False),
        (torch.bfloat16, True, True, False),
    ])
def test_export_gemm(
    precision,  # Precision of inputs, weights, output and bias
    use_fp8,
    use_bias,
    use_gelu,
    scale_factors
):
    """Export and validate the (FP8 and non-FP8) GEMM ops."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # NOTE: the classes below read `out_features` and `hidden_size` from the enclosing
    # function scope; those are assigned further down, before the classes are instantiated.
    class TestFP8_GEMM(nn.Module):
        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision

            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)

            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision

        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)

            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)

            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.inp_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.weights_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                use_split_accumulator=False)
            return ret

    class Test_GEMM(nn.Module):
        def __init__(self, precision, use_bias=False, gelu=False):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision
            bias_size = out_features
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

        def forward(self, inp, weight):
            outp_type = self.precision

            # note: due to logic in lines 104:116 and L129 in cpp_extensions.py
            # it appears either bias OR gelu can be activated, not both
            ret, _, _ = gemm(
                weight,
                inp,
                outp_type,
                get_workspace(),

                # test bias
                bias=self.bias,
                use_bias=self.use_bias,

                # test gelu
                gelu=self.gelu,
                gelu_input=self.gelu_input,
                grad=False,  # only True for backward pass
                accumulate=False,
            )
            return ret

    # If gelu is applied then bias must be added, as defined by TE kernel.
    if use_gelu:
        assert use_bias
    # Set dimensions (these are arbitrary).
    out_features = 128
    hidden_size = 256
    in_features = 64
    inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
    weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    gelu_str = "_gelu" if use_gelu else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
    if use_fp8:
        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
        do_export(model, (inp, weight), fname, use_fp8)
        # BF16 models are exported but their results are not validated.
        if precision == torch.bfloat16:
            return
        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True)
    else:
        model = Test_GEMM(precision, use_bias, use_gelu)
        do_export(model, (inp, weight), fname, use_fp8)
        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2)


@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm(
    use_fp8: bool,
    scale_factor: float,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Export and validate the (FP8 and non-FP8) inference LayerNorm ops."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    inp_shape = [64, 32]

    # NOTE: the classes below read `inp` from the enclosing function scope; it is assigned
    # further down, before the classes are instantiated.
    class Test_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6  # An arbitrary small value

        def forward(self, inp):
            ret = texcpp.layernorm_fwd_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                zero_centered_gamma)
            return ret

    class TestFP8_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6  # An arbitrary small value

            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3

        def forward(self, inp):
            ret = texcpp.layernorm_fwd_fp8_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                zero_centered_gamma)

            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
            return ret

    inp = torch.randn(*inp_shape, device="cuda", dtype=precision)
    model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
    high_prec_str = dtype2str(precision)
    fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
    fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, inp, fname, use_fp8=use_fp8)
    # BF16 models are exported but their results are not validated.
    if precision not in (torch.bfloat16, ):
        validate_result(
            fname, inp, model, atol=1e-7, is_fp8=use_fp8, allow_cnt_errors=3)


@skip_FP8
@pytest.mark.parametrize("softmax_def", [
    softmax_defs.ScaledUpperTriangMaskedSoftmax,
    softmax_defs.ScaledMaskedSoftmax,
    softmax_defs.ScaledSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
def test_export_softmax(softmax_def, precision):
    """Export and validate the fused softmax kernels."""
    class Test_Softmax(nn.Module):
        def __init__(self, softmax_function, mask_inp=False):
            super().__init__()
            self.softmax_fn = softmax_function
            self.mask_inp = mask_inp

        def forward(self, inp, mask):
            scale_factor = 8  # arbitrary value
            if self.mask_inp:
                ret = self.softmax_fn.apply(inp, mask, scale_factor)
            else:
                ret = self.softmax_fn.apply(inp, scale_factor)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    mask = None
    input_names = ["input"]
    inp_shape = [hidden_size, in_features, in_features, in_features]
    if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax:
        inp_shape = [hidden_size, in_features, in_features]
        kernel_str = "ScaledUpperTriangMaskedSoftmax"
        model = Test_Softmax(softmax_def)
    elif softmax_def == softmax_defs.ScaledMaskedSoftmax:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
        mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        input_names.append("mask")
        kernel_str = "ScaledMaskedSoftmax"
        model = Test_Softmax(softmax_def, mask_inp=True)
    elif softmax_def == softmax_defs.ScaledSoftmax:
        kernel_str = "ScaledSoftmax"
        model = Test_Softmax(softmax_def)
    input_tensor = torch.randn(*inp_shape, device="cuda")
    input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
    high_prec_str = dtype2str(precision)
    fname = f"{kernel_str}{high_prec_str}.onnx"
    inp = (input_tensor, mask)
    do_export(model, inp, fname, input_names=input_names)
    # BF16 models are exported but their results are not validated.
    if precision != torch.bfloat16:
        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)


@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
    "precision, use_bias", [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, False),
        (torch.float16, True),
        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
        (torch.bfloat16, False),
        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
        # (torch.bfloat16, True),
    ])
def test_export_linear(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    precision: torch.dtype
):
    """Export and validate te.Linear."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256

    class Test_Linear(nn.Module):
        def __init__(self,
                     in_features,
                     out_features,
                     use_bias,
                     return_bias,
                     precision
                     ):
            super().__init__()
            self.linear = te.Linear(
                in_features,
                out_features,
                bias=use_bias,
                return_bias=return_bias,
                params_dtype=precision
            )

        def forward(self, inp):
            ret = self.linear(inp)
            return ret

    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    with te.fp8_autocast(enabled=use_fp8):
        model = Test_Linear(
            in_features,
            out_features,
            use_bias,
            return_bias,
            precision
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model.linear, scale_factor, num_gemms=1)
        do_export(model, inp, fname, use_fp8)

        # BF16 models are exported but their results are not validated.
        if precision in (torch.bfloat16, ):
            return
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3)
        else:
            validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)


@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision, use_bias", [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
    ])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_linear(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Export and validate te.LayerNormLinear."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256

    # The last input dimension (out_features) equals hidden_size, the layer's input width.
    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"

    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormLinear(
            hidden_size,
            3 * hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=1)
        do_export(model, inp, fname, use_fp8)
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3)
        elif precision != torch.bfloat16:
            validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)


@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision, use_bias", [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
    ])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_mlp(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Export and validate te.LayerNormMLP."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256
    ffn_hidden_size = 256

    # The last input dimension (out_features) equals hidden_size, the layer's input width.
    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=2)
        do_export(model, inp, fname, use_fp8)
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3)
        else:
            validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)


@skip_FP8
@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type", [
        (torch.float32, False, None),       # calls forward_torch_softmax
        (torch.float32, True, None),        # calls forward_torch_softmax
        (torch.float16, False, "causal"),   # calls ScaledUpperTriangMaskedSoftmax
        (torch.float16, True, "padding"),   # calls ScaledMaskedSoftmax
        (torch.float16, False, "padding"),  # calls ScaledSoftmax
    ])
def test_export_core_attention(
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
):
    """Export and validate te.transformer.DotProductAttention."""
    # Set dimensions (these are arbitrary).
    kv_channels = 64
    num_attention_heads = 1
    qkv_size = (2048, 4, num_attention_heads, kv_channels)

    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    input_names = ["query", "key", "value"]
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        input_names.append("attention_mask")
    inp = (query_layer, key_layer, value_layer, attention_mask)

    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)
    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

    if attn_mask_type is None:
        attn_mask_type = 'causal'
        # The mask is dropped from the inputs here, so keep the input names in sync
        # with the trimmed input tuple.
        input_names = ["query", "key", "value"]
        inp = (query_layer, key_layer, value_layer)
    model = te.transformer.DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
        attn_mask_type=attn_mask_type,
    ).to(device='cuda')
    do_export(model,
              inp,
              fname,
              input_names=input_names,
              use_fp8=True)
    validate_result(fname, inp, model, atol=1e-2, input_names=input_names)


test_configs_multihead_attention = [
    # "use_mask, attn_mask_type"
    (False, "causal"),   # calls ScaledUpperTriangMaskedSoftmax
    (True, "padding"),   # calls ScaledMaskedSoftmax
    (False, "padding"),  # calls ScaledSoftmax
]

test_configs_attention_type = [
    # "input_layernorm, attention_type, fuse_qkv_params"
    (True, "self", True),
    (False, "self", True),
    (True, "self", False),
    (False, "self", False),
    (True, "cross", True),
    (False, "cross", True),
    (True, "cross", False),
    (False, "cross", False),
]


@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    precision: torch.dtype,
    return_layernorm_output: bool,
    input_layernorm: bool,
    attention_type: str,
    fuse_qkv_params: bool
):
    """Export and validate te.transformer.MultiHeadAttention."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    hidden_size = 256
    sequence_length = 128
    batch_size = 4
    num_attention_heads = 32
    kv_channels = 8
    attention_dropout = 0.1
    layernorm_epsilon = 1e-5
    init_method = output_layer_init_method = get_default_init_method()
    attention_args = (
        hidden_size,
        num_attention_heads,
        kv_channels,
        attention_dropout,
        layernorm_epsilon,
        init_method,
        output_layer_init_method,
    )

    hidden_states = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")

    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)

    encoder_output = None
    if attention_type == "cross":
        encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    inp = (hidden_states, attention_mask, encoder_output)
    input_names = ["hidden_states", "attention_mask", "encoder_output"]
    output_names = ["output", "output_1"]

    fp8_str = "_fp8" if use_fp8 else ""
    dtype_str = dtype2str(precision)
    attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
    fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    input_ln_str = "_input-ln" if input_layernorm else ""
    fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"

    model = te.transformer.MultiHeadAttention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
    ).to(device='cuda')
    do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
    elif precision == torch.float32:
        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
                        input_names=input_names, output_names=output_names)
    else:
        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
                        input_names=input_names, output_names=output_names, allow_cnt_errors=3)


@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
    # True, # TO DO: handle this
    False
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_transformer_layer(
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    output_layernorm: bool,
    precision: torch.dtype,
    fuse_qkv_params: bool,
    zero_centered_gamma: bool
):
    """Export and validate te.TransformerLayer."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
    batch_size = 1
    ffn_hidden_size = 256
    num_attention_heads = 4

    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    input_names = ["input"]
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        input_names.append("attention_mask")
    inp = (input_tensor, attention_mask)

    fp8_str = "_fp8" if use_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma).to(device='cuda')
    # Export with the same input names that validation (and I/O serialization) uses.
    do_export(model, inp, fname, use_fp8, input_names=input_names)
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
    else:
        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names)


@pytest.mark.parametrize("use_fp8", [True])
@pytest.mark.parametrize("ln_scale_factor", [448*2])
@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
    use_fp8: bool,
    ln_scale_factor: float,
    gemm_scale_factors: Tuple[float, float],
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """This is a regression test for testing that all LN inputs have the same type.

    The test sets up GEMM with FP32 output which feeds into an LN that is configured
    with FP16 or BF16 weights and bias.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # NOTE: the classes below read `inp`, `out_features`, `hidden_size` and `TestFP8_GEMM`
    # from the enclosing function scope; those are assigned further down, before the
    # classes are instantiated.
    class TestFP8_GemmLayernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6  # An arbitrary small value

            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(ln_scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.gemm = TestFP8_GEMM(
                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors)

        def forward(self, inp, weight):
            x = self.gemm(inp, weight)
            x = texcpp.layernorm_fwd_fp8_inf(
                x,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                zero_centered_gamma)

            x = cast_from_fp8(
                x,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
            return x

    out_features = 128
    hidden_size = 128
    in_features = 128

    class TestFP8_GEMM(nn.Module):
        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision

            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)

            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision

        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)

            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)

            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.inp_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.weights_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                use_split_accumulator=False)
            return ret

    inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
    weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
    model = TestFP8_GemmLayernorm()
    high_prec_str = dtype2str(precision)
    fp8_str = "_fp8" if use_fp8 else ""
    fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, (inp, weight), fname, use_fp8=use_fp8)
    # BF16 models are exported but their results are not validated.
    if precision not in (torch.bfloat16, ):
        validate_result(
            fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2)


@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
    """Check that te.onnx_export toggles the export mode only inside the context."""
    assert not is_in_onnx_export_mode()
    with te.onnx_export(enabled):
        assert is_in_onnx_export_mode() == enabled
    assert not is_in_onnx_export_mode()