Unverified Commit 06486a00 authored by Neta Zmora, committed by GitHub

Fix GELU ONNX export (#111)



* Fix GELU ONNX export

* Wrap the GELU export with casts to/from FP32 to achieve the same compute precision as TE.
* Increase GELU export test thresholds.
* Change the export to ONNX opset 17 for a smaller representation of LayerNorm (a single node instead of a subgraph); see the sketch below.
* Remove the need for the LayerNorm workaround for ORT.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
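As an illustrative aside (not part of this commit): with a PyTorch build whose exporter supports opset 17, exporting a plain `torch.nn.LayerNorm` is expected to yield a single `LayerNormalization` node rather than the multi-node subgraph emitted at older opsets. A minimal sketch:

```python
# Minimal sketch, assuming a recent PyTorch with opset-17 export support.
import io

import onnx
import torch

model = torch.nn.LayerNorm(64)
inp = torch.randn(8, 64)

buf = io.BytesIO()
torch.onnx.export(model, inp, buf, opset_version=17)

graph = onnx.load_model_from_string(buf.getvalue()).graph
# Expected to contain a single "LayerNormalization" op instead of the
# ReduceMean/Sub/Pow/Sqrt/Div subgraph produced by earlier opsets.
print([node.op_type for node in graph.node])
```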

* Add docstring to te_onnx_extensions.py::compute_in_fp32
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Tune threshold for GELU ONNX export

Ran 8K test instances to verify the threshold.
Allow up to two coefficients to exceed the threshold; two out-of-tolerance
coefficients are not considered a failure (see the sketch below).
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
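As an illustrative aside, a sketch of what "allow a couple of coefficients to escape the threshold" can look like; the helper names below are hypothetical, and the real `validate_result` implementation is only partially visible in this diff:

```python
# Minimal sketch, assuming NumPy arrays holding the TE and ORT outputs.
import numpy as np

def count_errors(te_out: np.ndarray, ort_out: np.ndarray, atol: float) -> int:
    """Count elements whose absolute difference exceeds atol."""
    return int(np.sum(~np.isclose(te_out, ort_out, rtol=0, atol=atol)))

def outputs_match(te_out, ort_out, atol=1e-5, allow_cnt_errors=2) -> bool:
    """Tolerate up to allow_cnt_errors out-of-tolerance elements."""
    return count_errors(te_out, ort_out, atol) <= allow_cnt_errors
```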

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5a881a08
@@ -51,7 +51,7 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
 # ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
 TRILU_OPSET = 14
 # Opset used in the ONNX files generated by the tests.
-OPSET = 15
+OPSET = 17
 assert OPSET >= TRILU_OPSET
 # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
@@ -158,9 +158,7 @@ def validate_result(
     print("registered custom FP8 Q/DQ ops!")

     """Create an ONNX Runtime session for validation."""
-    # Workaround an ORT limitation. See https://github.com/microsoft/onnxruntime/issues/15021
-    kwargs = {"disabled_optimizers": ["LayerNormFusion"]}
+    kwargs = {}
     if is_fp8:
         sess_options = ort.SessionOptions()
         load_custom_ops(sess_options)
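Illustrative aside: with the LayerNormFusion workaround gone, session creation only needs the custom FP8 Q/DQ op library registered. A sketch under the assumption of placeholder library and model paths:

```python
# Minimal sketch, assuming onnxruntime and a custom-op shared library at the
# placeholder path below (both paths are hypothetical).
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Register the shared library implementing the custom FP8 Q/DQ operators.
sess_options.register_custom_ops_library("./libcustom_ort_fp8_qdq_ops.so")

session = ort.InferenceSession(
    "te.gelu_fp8_448.onnx",  # placeholder model path
    sess_options,
    providers=["CPUExecutionProvider"],
)
```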
@@ -310,8 +308,8 @@ def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype
 @pytest.mark.parametrize("scale_factor", [448])
 @pytest.mark.parametrize(
     "precision, atol", [
-        [torch.float32, 1e-7],
-        [torch.float16, 2e-3]
+        [torch.float32, 1e-5],
+        [torch.float16, 1e-5]
     ])
 def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
     class TestFP8_Gelu(nn.Module):
@@ -344,7 +342,7 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
     fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
     model = TestFP8_Gelu()
     do_export(model, inp, fname)
-    validate_result(fname, inp, model, rtol=1e-1, atol=atol, is_fp8=True)
+    validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2)

 @pytest.mark.parametrize("scale_factors",
@@ -486,7 +484,7 @@ def test_export_gemm(
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("scale_factor", [448, 112])
-@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_layernorm(
     use_fp8: bool,
@@ -665,7 +663,6 @@ def test_export_linear(
             ret = self.linear(inp)
             return ret
     inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
     fp8_str = "_fp8" if use_fp8 else ""
     bias_str = "_bias" if use_bias else ""
...
@@ -76,6 +76,22 @@ def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
     return out

+
+def compute_in_fp32(g, inp, subgraph, cast_outp):
+    """Wrap subgraph with casts to/from FP32 so that its precision is FP32.
+
+    If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`.
+    Then, if `cast_outp` is true, cast the subgraph's output back to the `inp` data type.
+    """
+    inp_dtype = _type_utils.JitScalarType.from_value(inp)
+    is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT
+    if not is_fp32:
+        inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+    sg_out = subgraph(inp)
+    if not is_fp32 and cast_outp:
+        sg_out = g.op("Cast", sg_out, to_i=_type_utils.JitScalarType(inp_dtype).onnx_type())
+    return sg_out
+
+
 @symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
 def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for cast_to_fp8"""
@@ -94,7 +110,10 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
 def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_gelu"""
     # pylint: disable=unused-argument
-    gelu = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
+    wrapped_gelu = lambda inputs: torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
+    # TE computes GELU using float32 precision so wrap the GELU subgraph with
+    # conversion to/from float32.
+    gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=False)
     out = quantize(g, gelu, scale_inv, fp8_tensor)
     return out
@@ -181,28 +200,21 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
     normalized_shape = normalized_shape[1:]

     if zero_centered_gamma:
-        one = g.op("Constant", value_t=torch.tensor([1.], dtype=torch.float, device="cuda"))
+        inputs_dtype = inputs.type().dtype()
+        one = g.op("Constant", value_t=torch.tensor([1.], dtype=inputs_dtype, device="cuda"))
         weight = g.op("Add", weight, one)

-    # TE computes LN using float32 precision so wrap the LN subgraph with
-    # conversion to/from float32.
-    input_dtype = _type_utils.JitScalarType.from_value(inputs)
-    is_fp32 = input_dtype == _type_utils.JitScalarType.FLOAT
-    if not is_fp32:
-        inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
-    ln = torch.onnx.symbolic_opset9.layer_norm(
-        g,
+    axis = -len(normalized_shape)
+    ln = g.op(
+        "LayerNormalization",
         inputs,
-        normalized_shape,
         weight,
         bias,
-        eps,
-        False  # cudnn_enable (not relevant)
+        epsilon_f=eps,
+        axis_i=axis,
+        # This sets the LN compute precision - use FP32 always as does TE.
+        stash_type_i=_C_onnx.TensorProtoDataType.FLOAT,
     )
-    if not is_fp32:
-        ln = g.op("Cast", ln, to_i=_type_utils.JitScalarType(input_dtype).onnx_type())
     return ln
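Illustrative aside: a standalone sketch of the opset-17 graph this symbolic roughly corresponds to, a single `LayerNormalization` node with the stash (compute) type pinned to FP32 as TE does; the shapes and FP16 dtypes below are hypothetical:

```python
# Minimal sketch, assuming the onnx package; not part of the exporter code.
import onnx
from onnx import TensorProto, helper

ln_node = helper.make_node(
    "LayerNormalization",
    inputs=["inputs", "weight", "bias"],
    outputs=["out"],
    axis=-1,
    epsilon=1e-5,
    # stash_type pins the mean/variance computation to FP32, as TE does.
    stash_type=TensorProto.FLOAT,
)
graph = helper.make_graph(
    [ln_node],
    "ln_single_node_example",
    inputs=[
        helper.make_tensor_value_info("inputs", TensorProto.FLOAT16, [8, 64]),
        helper.make_tensor_value_info("weight", TensorProto.FLOAT16, [64]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT16, [64]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT16, [8, 64])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
onnx.checker.check_model(model)
```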
...