Unverified Commit 83911ddb authored by Neta Zmora, committed by GitHub

ONNX export refactoring (#197)



* ONNX export refactoring

* Remove infer_ort (to enable more testing)
* Add BF16 ORT tests for Q/DQ ops and GELU.
  * Use FP32 i/o instead of BF16 (ORT does not support BF16 i/o) and add casts between FP32 and BF16 at the subgraph inputs and outputs only (a minimal sketch of this pattern follows this list).
  * We'll need to add more BF16 testing.
* GEMM:
  * Add a cast after DQ to improve performance (the matmul then runs at sub-FP32 precision).
  * Fold the bias into the Gemm operation (producing smaller graphs).
  * Wrap GEMM-GELU in FP32 casts, since TE computes GELU in FP32 (see the GEMM sketch following this list).
* Enable tests for cross attention (test_export_multihead_attention)
* Tighten test tolerances for test_export_layernorm_mlp, test_export_layernorm_linear, and test_export_layernorm
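
A minimal sketch of the fake BF16 i/o pattern, assuming a generic wrapped module (the class and tensor names here are illustrative, not the actual test code): FP32 stays at the graph boundary so ORT can feed and fetch the inputs/outputs, while the wrapped computation still runs in BF16.

```python
import torch
import torch.nn as nn

class FakeBF16IO(nn.Module):
    """Keep FP32 at the module boundary while computing in BF16.

    ORT does not support BF16 model inputs/outputs, so the exported subgraph
    takes FP32, casts to BF16 for the wrapped computation, and casts the
    result back to FP32 before returning it.
    """
    def __init__(self, wrapped: nn.Module):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        ret = self.wrapped(inp.to(torch.bfloat16))  # compute in BF16
        return ret.to(torch.float32)                # FP32 at the boundary

# The dummy export input (and the ORT reference input) stays FP32.
model = FakeBF16IO(nn.GELU())
out = model(torch.randn(256, 64, dtype=torch.float32))
assert out.dtype == torch.float32
```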
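And a rough numerical illustration of what the refactored GEMM export computes (a sketch of the exported graph's math, not the exporter's symbolic code; the function name, shapes, and the BF16 default are assumptions): dequantize, cast so the matmul runs below FP32, fold the bias into the Gemm, then run GELU in FP32 and cast back.

```python
import torch

def gemm_gelu_reference(x_deq: torch.Tensor,
                        w_deq: torch.Tensor,
                        bias: torch.Tensor,
                        out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Rough reference for the exported DQ -> Cast -> Gemm(+bias) -> GELU pattern.

    x_deq / w_deq stand in for already-dequantized FP8 tensors (dequantization
    yields FP32 values). In the real export the cast targets the Gemm output
    type, e.g. FP16 when the model runs in FP16.
    """
    # Cast after DQ so the matmul runs at a sub-FP32 precision.
    x = x_deq.to(out_dtype)
    w = w_deq.to(out_dtype)
    # Bias folded into the Gemm: a single fused op instead of Gemm + Add.
    y = torch.addmm(bias.to(out_dtype), x, w.t())
    # TE computes GELU in FP32, so the GELU subgraph is wrapped with casts.
    y = torch.nn.functional.gelu(y.to(torch.float32), approximate="tanh")
    return y.to(out_dtype)

out = gemm_gelu_reference(torch.randn(8, 64), torch.randn(32, 64), torch.randn(32))
```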
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Loosen MHA export validation thresholds for FP16
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
parent a2c9c635
@@ -69,6 +69,7 @@ ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
@@ -152,7 +153,6 @@ def validate_result(
allow_cnt_errors: int=0,
input_names: list=["input"],
output_names: list=["output"],
infer_ort=True
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
@@ -248,7 +248,6 @@ def validate_result(
# Run ORT session and TE model.
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
te_outputs = te_infer(model, inps, is_fp8)
if infer_ort:
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
@@ -290,20 +289,28 @@ def get_attn_mask_str(use_mask, attn_mask_type):
return attn_mask_str
"""
Test cases begin here.
"""
@skip_FP8
@pytest.mark.parametrize("scale_factor, atol", [
(1, 1e-7),
(224, 1e-7)
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
"precision, atol", [
[torch.float32, 1e-7],
[torch.float16, 1e-7],
[torch.bfloat16, 5e-3]
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
class TestFP8_QDQ(nn.Module):
def __init__(self):
def __init__(self, fake_bf16_io):
super().__init__()
self.fp8_tensor = 0
self.meta = create_meta(scale_factor)
self.highprec_type = as_te_type(precision)
self.fp8_type = tex.DType.kFloat8E4M3
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
ret = cast_to_fp8(
@@ -318,34 +325,39 @@ def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtyp
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
fake_bf16_io = precision == torch.bfloat16
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
high_prec_str = dtype2str(precision)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ()
model = TestFP8_QDQ(fake_bf16_io)
do_export(model, inp, fname)
validate_result(fname, inp, model, atol=atol, is_fp8=True)
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
"precision, atol", [
[torch.float32, 1e-5],
[torch.float16, 1e-5]
[torch.float16, 1e-5],
[torch.bfloat16, 5e-3]
])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
class TestFP8_Gelu(nn.Module):
def __init__(self):
def __init__(self, fake_bf16_io):
super().__init__()
self.fp8_tensor = 0
self.meta = create_meta(scale_factor)
self.highprec_type = as_te_type(precision)
self.fp8_type = tex.DType.kFloat8E4M3
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
ret = fp8_gelu(
@@ -359,15 +371,19 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
fake_bf16_io = precision == torch.bfloat16
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
high_prec_str = dtype2str(precision)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu()
model = TestFP8_Gelu(fake_bf16_io)
do_export(model, inp, fname)
validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2)
@@ -481,7 +497,8 @@ def test_export_gemm(
# test gelu
gelu=self.gelu,
gelu_input=self.gelu_input,
grad=False # only True for backward pass
grad=False, # only True for backward pass
accumulate=False,
)
return ret
@@ -503,8 +520,7 @@ def test_export_gemm(
do_export(model, (inp, weight), fname, use_fp8)
if precision == torch.bfloat16:
return
infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True, infer_ort=infer_ort)
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True)
else:
model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8)
@@ -584,7 +600,7 @@ def test_export_layernorm(
do_export(model, inp, fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
validate_result(
fname, inp, model, atol=1e-4, is_fp8=use_fp8, allow_cnt_errors=3)
fname, inp, model, atol=1e-7, is_fp8=use_fp8, allow_cnt_errors=3)
@skip_FP8
@@ -769,8 +785,8 @@ def test_export_layernorm_linear(
do_export(model, inp, fname, use_fp8)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
elif precision not in (torch.bfloat16,):
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
elif precision != torch.bfloat16:
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)
@pytest.mark.parametrize("scale_factor", [112])
@@ -826,7 +842,7 @@ def test_export_layernorm_mlp(
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
else:
validate_result(fname, inp, model, atol=2e-2, is_fp8=use_fp8)
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)
@skip_FP8
@pytest.mark.parametrize(
@@ -892,13 +908,10 @@ test_configs_attention_type = [
(False, "self", True),
(True, "self", False),
(False, "self", False),
# disabled because query_bias (reqd for cross attention) is defined when fuse_qkv_params is False
# (True, "cross", True),
# (False, "cross", True),
(True, "cross", True),
(False, "cross", True),
(True, "cross", False),
# disabled because TypeError: cannot assign 'transformer_engine.pytorch.module.Linear'
# as parameter 'query' (torch.nn.Parameter or None expected)
# (False, "cross", False),
(False, "cross", False),
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -969,12 +982,14 @@ def test_export_multihead_attention(
fuse_qkv_params=fuse_qkv_params,
).to(device='cuda')
do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
elif precision == torch.float32:
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names)
else:
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, infer_ort=infer_ort)
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -1031,11 +1046,10 @@ def test_export_transformer_layer(
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
do_export(model, inp, fname, use_fp8)
infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
else:
validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names, infer_ort=infer_ort)
validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names)
@pytest.mark.parametrize("use_fp8", [True])
......
@@ -29,14 +29,12 @@ from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
# Custom ops spec version
VER = 1
UNSPECIFIED_TYPE = -1
@@ -139,10 +137,10 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
wrapped_gelu = lambda inputs: torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
# TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32.
gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=False)
gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)
out = quantize(g, gelu, scale_inv, fp8_tensor)
return out
@@ -179,27 +177,26 @@ def onnx_te_gemm(
# pylint: disable=unused-argument
is_fp16 = is_dtype_fp16(inputs)
if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
if weight_type == int(tex.DType.kFloat8E4M3):
weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, out_type)
empty_tensor_size = [0]
bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
== empty_tensor_size
if not bias_empty:
if pre_gelu_out_empty:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
output = torch.onnx.symbolic_opset9.gelu(g, output)
output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
if not bias_empty:
if not pre_gelu_out_empty:
wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
# TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32.
output = compute_in_fp32(g, output, wrapped_gelu, cast_outp=True)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
......