Unverified commit fee8970f, authored by Kirthi Shankar Sivamani and committed by GitHub

ONNX export for ReLU and GLU variants (#281)



* ReLU ONNX export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add GLU variants
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* linter check
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Export reglu, geglu, swiglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3f13e55f
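For orientation, a minimal sketch of what this change enables: exporting a TE module with one of the newly supported activations to ONNX. The sizes, file name, and opset below are illustrative, and `te.onnx_export` is assumed to be the export context manager the test suite uses:

```python
import torch
import transformer_engine.pytorch as te

# Illustrative: "swiglu" could be any of gelu, relu, reglu, geglu, swiglu.
model = te.LayerNormMLP(1024, 4096, activation="swiglu").to("cuda")
inp = torch.randn(64, 1024, device="cuda")

# te.onnx_export switches TE to its export-friendly inference kernels.
with te.onnx_export(True):
    torch.onnx.export(model, (inp,), "layernorm_mlp_swiglu.onnx", opset_version=15)
```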
@@ -69,6 +69,8 @@ ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]

@pytest.fixture()
def seed_default_rng():
@@ -974,6 +976,7 @@ def test_export_layernorm_linear(
(torch.bfloat16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
@@ -982,7 +985,8 @@ def test_export_layernorm_mlp(
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
@@ -998,7 +1002,7 @@ def test_export_layernorm_mlp(
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
with te.fp8_autocast(enabled=use_fp8):
model = te.LayerNormMLP(
hidden_size,
@@ -1008,6 +1012,7 @@
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
).to(device='cuda')
if use_fp8:
    set_layer_scale(model, scale_factor, num_gemms=2)
@@ -1016,10 +1021,9 @@
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
    return
if not use_fp8:
    validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
else:
    validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)

@skip_FP8
@pytest.mark.parametrize(
@@ -1211,6 +1215,7 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
@@ -1220,7 +1225,8 @@
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
@@ -1256,18 +1262,15 @@
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
zero_centered_gamma=zero_centered_gamma,
activation=activation).to(device='cuda')
do_export(model, inp, fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ):
    return
if not use_fp8:
    validate_result(fname, inp, model, atol=1e-3, input_names=input_names,
                    te_outputs=te_outputs)
else:
    validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names,
                    te_outputs=te_outputs)
atol = 5e-1 if use_fp8 or activation == "swiglu" else 1e-3
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs)

@pytest.mark.parametrize("use_fp8", [True])
@@ -1405,7 +1408,7 @@ def test_export_gpt_generation(
set_max_seq_len,
use_fp8: bool,
precision: torch.dtype,
zero_centered_gamma: bool
zero_centered_gamma: bool,
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
@@ -17,6 +17,7 @@ from .te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_from_fp8,
onnx_fp8_gelu,
onnx_fp8_relu,
onnx_te_gemm,
onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd,
@@ -23,12 +23,16 @@ the following error when accessing a specific scale element (e.g. `scale_inv[fp8
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils
import torch._C._onnx as _C_onnx
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx._internal import jit_utils
import transformer_engine_extensions as tex
# This file registers ONNX symbolic functions for custom ops and does not export any symbols.
__all__ = []
@@ -141,8 +145,87 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
# TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32.
gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)
out = quantize(g, gelu, scale_inv, fp8_tensor)
return out
if scale_inv:
    gelu = quantize(g, gelu, scale_inv, fp8_tensor)
return gelu

@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_relu"""
# pylint: disable=unused-argument
wrapped_relu = lambda inps: torch.onnx.symbolic_opset9.relu(g, inps)
relu = compute_in_fp32(g, inputs, wrapped_relu, cast_outp=True)
if scale_inv:
    relu = quantize(g, relu, scale_inv, fp8_tensor)
return relu

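For intuition, every `fp8_*` symbolic here follows one pattern: compute the activation in float32, then quantize only when a scale-inverse is supplied (i.e., FP8 export). A rough eager-mode analogue, with a crude per-tensor scale-and-clamp standing in for the graph's `quantize()` helper (hypothetical reference, not part of the commit):

```python
from typing import Optional
import torch

def fp8_relu_reference(x: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirror the exported graph: run the op in float32 for accuracy...
    y = torch.relu(x.float())
    if scale is not None:
        # ...then quantize only in FP8 mode; clamping to the E4M3 dynamic
        # range (|x| <= 448) stands in for the real FP8 Q/DQ ops.
        y = torch.clamp(y * scale, -448.0, 448.0)
    return y.to(x.dtype)
```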
@symbolic_helper.parse_args("v", "i")
def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for swiglu"""
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
    assert dim_size % 2 == 0
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
return g.op("Mul", g.op("Sigmoid", first), second)

@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_swiglu"""
# pylint: disable=unused-argument
wrapped_swiglu = lambda inps: onnx_swiglu(g, inps, 1)
swiglu = compute_in_fp32(g, inputs, wrapped_swiglu, cast_outp=True)
if scale_inv:
    swiglu = quantize(g, swiglu, scale_inv, fp8_tensor)
return swiglu

@symbolic_helper.parse_args("v", "i")
def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for reglu"""
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
    assert dim_size % 2 == 0
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
return g.op("Mul", g.op("Relu", first), second)

@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_reglu"""
# pylint: disable=unused-argument
wrapped_reglu = lambda inps: onnx_reglu(g, inps, 1)
reglu = compute_in_fp32(g, inputs, wrapped_reglu, cast_outp=True)
if scale_inv:
    reglu = quantize(g, reglu, scale_inv, fp8_tensor)
return reglu

@symbolic_helper.parse_args("v", "i")
def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for geglu"""
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
    assert dim_size % 2 == 0
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
first_gelu = torch.onnx.symbolic_opset9.gelu(g, first, "tanh")
return g.op("Mul", first_gelu, second)

@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_geglu"""
# pylint: disable=unused-argument
wrapped_geglu = lambda inps: onnx_geglu(g, inps, 1)
geglu = compute_in_fp32(g, inputs, wrapped_geglu, cast_outp=True)
if scale_inv:
    geglu = quantize(g, geglu, scale_inv, fp8_tensor)
return geglu

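All three GLU variants share one shape and differ only in the gate applied to the first half; an illustrative generalization (helper name hypothetical):

```python
import torch
import torch.nn.functional as F

def glu_variant_reference(x: torch.Tensor, gate, dim: int = 1) -> torch.Tensor:
    # Generic Split + gate + Mul, mirroring the three symbolics above.
    first, second = x.chunk(2, dim=dim)
    return gate(first) * second

# reglu:  gate = torch.relu                               -> ONNX Relu
# geglu:  gate = lambda t: F.gelu(t, approximate="tanh")  -> tanh-approximated Gelu
# swiglu: gate = torch.sigmoid                            -> ONNX Sigmoid
```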
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
@@ -255,6 +338,10 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
register_custom_op_symbolic('tex_ts::reglu_ts', onnx_fp8_reglu, VER)
register_custom_op_symbolic('tex_ts::geglu_ts', onnx_fp8_geglu, VER)
register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
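As a usage note, `register_custom_op_symbolic` is standard `torch.onnx` machinery: it binds a TorchScript op name to a symbolic function from the given opset onward, which is how the `tex_ts::*` ops above are lowered during export. A self-contained sketch with a hypothetical op name:

```python
from torch.onnx import register_custom_op_symbolic

# Hypothetical op: lower my_ns::fancy_relu to the standard ONNX Relu.
def fancy_relu_symbolic(g, inp):
    return g.op("Relu", inp)

# Registered for opset 9 and above (analogous to VER above).
register_custom_op_symbolic("my_ns::fancy_relu", fancy_relu_symbolic, 9)
```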