Unverified Commit 44d64abc authored by Neta Zmora, committed by GitHub

Add a temporary workaround to layernorm ONNX export (#95)



* Add a temporary workaround to layernorm export

Seems like ORT performs template matching for LN and incorrectly concludes
that it doesn't have a kernel for FP32 LN. The workaround adds a `fake_zero`
addition, which is meant to defeat the template matching while keeping the graph
virtually unchanged. This also requires `do_constant_folding=False` in
`torch.onnx.export` (see the sketch below).
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
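
For context, a minimal sketch of the export call this constraint applies to; the module, shapes, opset, and file name are illustrative stand-ins, not taken from this patch:

```python
import torch
from torch import nn

# Illustrative stand-in for a TE module; any nn.Module exports the same way.
model = nn.LayerNorm(1024).eval()
inp = torch.randn(16, 1024)

# Constant folding must stay disabled so the workaround nodes survive export.
torch.onnx.export(
    model,
    (inp,),
    "te.layernorm.onnx",
    opset_version=15,
    input_names=["input"],
    output_names=["output"],
    do_constant_folding=False,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
```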

* Adjust test threshold
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Opened an ORT bug and added the link for tracking
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Fix Python linter errors
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Simplify the LN workaround solution (ONNX export)

After discussing https://github.com/microsoft/onnxruntime/issues/15021
with Microsoft engineers, replaced the LN workaround with a simpler
implementation (sketched below).

In addition:
* To make the test more robust, add `allow_cnt_errors` to `validate_result`
* Add more documentation to clarify the purpose and methodology of the
  ONNX export tests
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
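
The simpler workaround lives on the ORT side: the `LayerNormFusion` graph optimizer is disabled when the validation session is created. A minimal sketch, assuming an exported model file named `te.layernorm.onnx`:

```python
import onnxruntime as ort

# Disabling LayerNormFusion sidesteps the template-matching issue tracked in
# https://github.com/microsoft/onnxruntime/issues/15021.
session = ort.InferenceSession(
    "te.layernorm.onnx",
    disabled_optimizers=["LayerNormFusion"],
)
```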

* Fix unused import
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Fix unused import
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Fix unused import
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 05366e5f
@@ -4,13 +4,21 @@
"""
This file contains tests for exporting TransformerEngine models to ONNX.
The purpose of these tests is to validate that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.
Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.
"""
import os
import pytest
import warnings
import numpy as np
import math
import onnxruntime as ort
import torch
from torch import nn as nn
@@ -76,7 +84,7 @@ def do_export(
        opset_version=opset,
        input_names=input_names,
        output_names=output_names,
        do_constant_folding=True,
        do_constant_folding=False,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
@@ -114,8 +122,20 @@ def validate_result(
    rtol: float=1.e-5,  # np.isclose default rtol
    max_errors_printed: int=10,
    is_fp8: bool=False,
    allow_cnt_errors: int=0,
):
"""Validate the outputs of an ONNX model vs. ONNX Runtime."""
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
The purpose of the output comparison is to validate that TE models are converted to
their correct ONNX representation by testing that TE and ORT outputs match within some
small threshold (allowing for finite precision errors).
Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring,
a very small number (0-3) of outliers. This is fine to do because these outliers are due to
small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
representation (the tests assume both ORT or TE kernels are correct).
"""
    def create_ort_session(fname: str, is_fp8: bool):
        def load_custom_ops(session_opts: ort.SessionOptions):
@@ -126,13 +146,15 @@ def validate_result(
print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation."""
# Workaround an ORT limitation. See https://github.com/microsoft/onnxruntime/issues/15021
kwargs = {"disabled_optimizers": ["LayerNormFusion"]}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
# Model loading successfully indicates that the custom op node could be resolved successfully
s = ort.InferenceSession(fname, sess_options=sess_options)
else:
s = ort.InferenceSession(fname)
kwargs["sess_options"] = sess_options
s = ort.InferenceSession(fname, **kwargs)
return s
def create_ort_input_dict(session, inps):
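
For the FP8 path, the custom Q/DQ kernels come from a custom-op library registered on the session options. A sketch of that flow, with the shared-library path as an illustrative assumption:

```python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# The library path is illustrative; the tests build their own custom FP8 Q/DQ ops.
sess_options.register_custom_ops_library("./libcustom_fp8_qdq.so")
session = ort.InferenceSession(
    "te.model.onnx",
    sess_options=sess_options,
    disabled_optimizers=["LayerNormFusion"],
)
```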
@@ -174,6 +196,7 @@ def validate_result(
    for loc in mismatched_ids[:nb_vals]:
        ref = te_output[loc]
        print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
    if len(mismatched_ids) > allow_cnt_errors:
        raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors")
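
Stripped of the test plumbing, the tolerant comparison reduces to counting `np.isclose` mismatches and failing only past the allowance; a standalone sketch with illustrative names:

```python
import numpy as np

def compare_outputs(te_output, onnx_output, atol=1e-4, rtol=1e-5, allow_cnt_errors=3):
    """Fail only if more than `allow_cnt_errors` elements fall outside tolerance."""
    mismatched = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
    n_errors = int(mismatched.sum())
    if n_errors > allow_cnt_errors:
        raise ValueError(f"Output validation failed with {n_errors} errors")
```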
@@ -502,8 +525,8 @@ def test_export_layernorm(
fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
do_export(model, inp, fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
# TODO: FP32 has a small threshold (1e-5)
validate_result(fname, inp, model, atol=4e-3, is_fp8=use_fp8)
validate_result(
fname, inp, model, atol=1e-4, is_fp8=use_fp8, allow_cnt_errors=3)
@skip_FP8
@@ -23,11 +23,13 @@ the following error when accessing a specific scale element (e.g. `scale_inv[fp8
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic
from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
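
Symbolic functions in this file are hooked up through `torch.onnx.register_custom_op_symbolic`; a sketch of the pattern, where the op name string and opset are illustrative, not taken from this patch:

```python
from torch.onnx import register_custom_op_symbolic

def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
    ...  # symbolic implementation, as defined later in this file

# Map the TE custom op to its ONNX symbolic (name and opset illustrative).
register_custom_op_symbolic("tex_ts::layernorm_fwd_ts", onnx_layernorm_fwd, 9)
```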
@@ -179,8 +181,16 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
        normalized_shape = normalized_shape[1:]
    if zero_centered_gamma:
        one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64, device="cuda"))
        one = g.op("Constant", value_t=torch.tensor([1.], dtype=torch.float, device="cuda"))
        weight = g.op("Add", weight, one)
    # TE computes LN using float32 precision, so wrap the LN subgraph with
    # conversions to/from float32.
    input_dtype = _type_utils.JitScalarType.from_value(inputs)
    is_fp32 = input_dtype == _type_utils.JitScalarType.FLOAT
    if not is_fp32:
        inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    ln = torch.onnx.symbolic_opset9.layer_norm(
        g,
        inputs,
@@ -190,6 +200,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
        eps,
        False  # cudnn_enable (not relevant)
    )
    if not is_fp32:
        ln = g.op("Cast", ln, to_i=_type_utils.JitScalarType(input_dtype).onnx_type())
    return ln
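
In eager terms, the Cast-wrapped subgraph computes the following; this is a minimal sketch of the numeric behavior, not the exporter code itself:

```python
import torch
import torch.nn.functional as F

def layernorm_fp32(x, weight, bias, eps=1e-5):
    # Compute LN in float32 (matching TE's internal precision), then cast back.
    y = F.layer_norm(x.float(), x.shape[-1:], weight.float(), bias.float(), eps)
    return y.to(x.dtype)
```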