Unverified commit fee8970f authored by Kirthi Shankar Sivamani, committed by GitHub

ONNX export for ReLU and GLU variants (#281)



* ReLU ONNX export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add GLU variants
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* linter check
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Export reglu, geglu, swiglu
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3f13e55f
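For orientation before the diff: the change threads a new `activation` argument through the test parametrization and the TE modules, and registers ONNX symbolic functions for `relu`, `reglu`, `geglu`, and `swiglu`. The sketch below is not part of the commit; it shows roughly how the new export path is exercised, assuming the `te.onnx_export` context manager that transformer_engine.pytorch provides for ONNX tracing. Sizes, opset, and file name are illustrative.

```python
# Rough usage sketch (not part of this diff): export a LayerNormMLP that uses
# one of the newly exportable activations. te.onnx_export is assumed from
# transformer_engine.pytorch; shapes, opset and file name are illustrative.
import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size = 256, 1024
model = te.LayerNormMLP(
    hidden_size,
    ffn_hidden_size,
    activation="swiglu",  # newly exportable: "relu", "reglu", "geglu", "swiglu"
).to(device="cuda")
inp = torch.randn(64, hidden_size, device="cuda")

with torch.inference_mode(), te.onnx_export(enabled=True):
    torch.onnx.export(model, (inp,), "layernorm_mlp_swiglu.onnx", opset_version=15)
```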
@@ -69,6 +69,8 @@ ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
 fp8_available, reason_for_no_fp8 = is_fp8_available()
 skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+
 @pytest.fixture()
 def seed_default_rng():
@@ -974,6 +976,7 @@ def test_export_layernorm_mlp(
     (torch.bfloat16, False),
 ])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+@pytest.mark.parametrize("activation", supported_activations)
 def test_export_layernorm_mlp(
     seed_default_rng,
     scale_factor: float,
@@ -982,7 +985,8 @@ def test_export_layernorm_mlp(
     return_bias: bool,
     return_layernorm_output: bool,
     precision: torch.dtype,
-    zero_centered_gamma: bool
+    zero_centered_gamma: bool,
+    activation: str,
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and not fp8_available:
@@ -998,7 +1002,7 @@ def test_export_layernorm_mlp(
     fp8_str = "_fp8" if use_fp8 else ""
     bias_str = "_bias" if use_bias else ""
     high_prec_str = dtype2str(precision)
-    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
+    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
     with te.fp8_autocast(enabled=use_fp8):
         model = te.LayerNormMLP(
             hidden_size,
@@ -1008,6 +1012,7 @@ def test_export_layernorm_mlp(
             return_layernorm_output=return_layernorm_output,
             params_dtype=precision,
             zero_centered_gamma=zero_centered_gamma,
+            activation=activation,
         ).to(device='cuda')
     if use_fp8:
         set_layer_scale(model, scale_factor, num_gemms=2)
@@ -1016,10 +1021,9 @@ def test_export_layernorm_mlp(
     serialize_inputs_outputs(fname, inp, te_outputs)
     if precision in (torch.bfloat16, ):
         return
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
-    else:
-        validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
+    atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
+    validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)


 @skip_FP8
 @pytest.mark.parametrize(
@@ -1211,6 +1215,7 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+@pytest.mark.parametrize("activation", supported_activations)
 def test_export_transformer_layer(
     seed_default_rng,
     set_max_seq_len,
@@ -1220,7 +1225,8 @@ def test_export_transformer_layer(
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    zero_centered_gamma: bool
+    zero_centered_gamma: bool,
+    activation: str,
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and not fp8_available:
@@ -1256,18 +1262,15 @@ def test_export_transformer_layer(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        zero_centered_gamma=zero_centered_gamma).to(device='cuda')
+        zero_centered_gamma=zero_centered_gamma,
+        activation=activation).to(device='cuda')
     do_export(model, inp, fname, use_fp8, input_names=input_names)
     te_outputs = te_infer(model, inp, is_fp8=use_fp8)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if precision in (torch.bfloat16, ):
         return
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3, input_names=input_names,
-                        te_outputs=te_outputs)
-    else:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names,
-                        te_outputs=te_outputs)
+    atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
+    validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs)


 @pytest.mark.parametrize("use_fp8", [True])
@@ -1405,7 +1408,7 @@ def test_export_gpt_generation(
     set_max_seq_len,
     use_fp8: bool,
     precision: torch.dtype,
-    zero_centered_gamma: bool
+    zero_centered_gamma: bool,
 ):
     """Test that the ONNX model can correctly handle inputs with different shapes and that
     the attention mask is adjusted on-the-fly to different sequence lengths.
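For the non-FP8 cases, `validate_result` amounts to running the exported model under ONNX Runtime and comparing against the TE module's output within `atol`. Below is a minimal standalone sketch of that check; `validate_onnx` is a hypothetical stand-in for the test suite's helper, it assumes a single-output module, and FP8 models would additionally need the custom-op library from `ORT_CUSTOM_OPS_LIB`.

```python
# Minimal sketch (not part of this diff) of the non-FP8 validation step.
# validate_onnx is a hypothetical stand-in for the suite's validate_result.
import numpy as np
import onnxruntime as ort
import torch

def validate_onnx(fname: str, model: torch.nn.Module, inp: torch.Tensor, atol: float):
    # Run the exported graph with ONNX Runtime.
    sess = ort.InferenceSession(fname, providers=["CPUExecutionProvider"])
    ort_out = sess.run(None, {sess.get_inputs()[0].name: inp.cpu().numpy()})[0]
    # Run the original TE module and compare elementwise within atol.
    with torch.no_grad():
        te_out = model(inp)
    np.testing.assert_allclose(ort_out, te_out.cpu().numpy(), atol=atol)
```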
@@ -17,6 +17,7 @@ from .te_onnx_extensions import (
     onnx_cast_to_fp8,
     onnx_cast_from_fp8,
     onnx_fp8_gelu,
+    onnx_fp8_relu,
     onnx_te_gemm,
     onnx_layernorm_fwd_fp8,
     onnx_layernorm_fwd,
@@ -23,12 +23,16 @@ the following error when accessing a specific scale element (e.g. `scale_inv[fp8
 TypeError: 'torch._C.Value' object is not subscriptable
 """
 import torch
 from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils
 import torch._C._onnx as _C_onnx
+
+# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
+from torch.onnx._internal import jit_utils
+
 import transformer_engine_extensions as tex

 # This file registers custom op symbolic ONNX functions and does not export any symbols.
 __all__ = []
@@ -141,8 +145,87 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     # TE computes GELU using float32 precision so wrap the GELU subgraph with
     # conversion to/from float32.
     gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)
-    out = quantize(g, gelu, scale_inv, fp8_tensor)
-    return out
+    if scale_inv:
+        gelu = quantize(g, gelu, scale_inv, fp8_tensor)
+    return gelu
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for fp8_relu"""
+    # pylint: disable=unused-argument
+    wrapped_relu = lambda inps: torch.onnx.symbolic_opset9.relu(g, inps)
+    relu = compute_in_fp32(g, inputs, wrapped_relu, cast_outp=True)
+    if scale_inv:
+        relu = quantize(g, relu, scale_inv, fp8_tensor)
+    return relu
+
+
+@symbolic_helper.parse_args("v", "i")
+def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
+    """ONNX graph for swiglu"""
+    dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
+    if dim_size is not None:
+        assert dim_size % 2 == 0
+    first, second = g.op("Split", inp, axis_i=dim, outputs=2)
+    return g.op("Mul", g.op("Sigmoid", first), second)
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for fp8_swiglu"""
+    # pylint: disable=unused-argument
+    wrapped_swiglu = lambda inps: onnx_swiglu(g, inps, 1)
+    swiglu = compute_in_fp32(g, inputs, wrapped_swiglu, cast_outp=True)
+    if scale_inv:
+        swiglu = quantize(g, swiglu, scale_inv, fp8_tensor)
+    return swiglu
+
+
+@symbolic_helper.parse_args("v", "i")
+def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
+    """ONNX graph for reglu"""
+    dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
+    if dim_size is not None:
+        assert dim_size % 2 == 0
+    first, second = g.op("Split", inp, axis_i=dim, outputs=2)
+    return g.op("Mul", g.op("Relu", first), second)
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for fp8_reglu"""
+    # pylint: disable=unused-argument
+    wrapped_reglu = lambda inps: onnx_reglu(g, inps, 1)
+    reglu = compute_in_fp32(g, inputs, wrapped_reglu, cast_outp=True)
+    if scale_inv:
+        reglu = quantize(g, reglu, scale_inv, fp8_tensor)
+    return reglu
+
+
+@symbolic_helper.parse_args("v", "i")
+def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
+    """ONNX graph for geglu"""
+    dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
+    if dim_size is not None:
+        assert dim_size % 2 == 0
+    first, second = g.op("Split", inp, axis_i=dim, outputs=2)
+    first_gelu = torch.onnx.symbolic_opset9.gelu(g, first, "tanh")
+    return g.op("Mul", first_gelu, second)
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for fp8_geglu"""
+    # pylint: disable=unused-argument
+    wrapped_geglu = lambda inps: onnx_geglu(g, inps, 1)
+    geglu = compute_in_fp32(g, inputs, wrapped_geglu, cast_outp=True)
+    if scale_inv:
+        geglu = quantize(g, geglu, scale_inv, fp8_tensor)
+    return geglu
 @symbolic_helper.parse_args("v", "fs", "i", "i", "i",
@@ -255,6 +338,10 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
 register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
 register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
 register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
+register_custom_op_symbolic('tex_ts::relu_ts', onnx_fp8_relu, VER)
+register_custom_op_symbolic('tex_ts::reglu_ts', onnx_fp8_reglu, VER)
+register_custom_op_symbolic('tex_ts::geglu_ts', onnx_fp8_geglu, VER)
+register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER)
 register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
 register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
 register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
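All three GLU symbolics above share one shape: `Split` the hidden dimension in half, apply the gate nonlinearity to the first half, and `Mul` with the second half; swiglu gates with `Sigmoid`, reglu with `Relu`, and geglu with the tanh-approximated opset-9 `gelu`. A quick way to sanity-check the pattern (again, not part of the commit) is to export an equivalent plain PyTorch module and inspect the node types:

```python
# Sanity check (not part of this diff): exporting an equivalent plain PyTorch
# module yields the same Split -> Relu -> Mul pattern that onnx_reglu emits.
import io
import onnx
import torch

class RegluRef(torch.nn.Module):
    def forward(self, x):
        first, second = torch.chunk(x, 2, dim=1)  # maps to ONNX Split
        return torch.relu(first) * second         # maps to ONNX Relu + Mul

buf = io.BytesIO()
torch.onnx.export(RegluRef(), (torch.randn(4, 8),), buf, opset_version=15)
graph = onnx.load_from_string(buf.getvalue()).graph
# Expect Split, Relu and Mul among the node types (Constant nodes may appear too).
print([node.op_type for node in graph.node])
```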