"nndet/vscode:/vscode.git/clone" did not exist on "31d06041a0c73396749b2fadb5aedd6cf152ded0"
Unverified commit 83911ddb, authored by Neta Zmora and committed by GitHub

ONNX export refactoring (#197)



* ONNX export refactoring

* Remove infer_ort (to enable more testing)
* Add BF16 ORT tests for Q/DQ ops and GELU.
  * Use FP32 i/o instead of BF16 (because ORT doesn't support BF16 i/o) and add casts between FP32 and BF16 at the subgraph inputs and outputs (see the sketch below the commit metadata).
  * We'll need to add more BF16 testing.
* GEMM:
  * Add a cast after DQ to achieve better performance (matmul at sub-FP32 precisions).
  * Fold bias into the Gemm operation (=> smaller graphs).
  * Wrap the GEMM-GELU sequence in FP32 casts (TE implements GELU in FP32).
* Enable tests for cross attention (test_export_multihead_attention)
* Reduce test thresholds for test_export_layernorm_mlp, test_export_layernorm_linear, test_export_layernorm
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Loosen MHA export validation thresholds for FP16
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
parent a2c9c635
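For context, here is a minimal sketch of the "fake BF16 i/o" pattern this commit introduces in the tests. ORT cannot bind BF16 inputs or outputs, so the exported module accepts and returns FP32 and casts to/from BF16 only at the subgraph boundary. The module and op below are hypothetical stand-ins, not the test code itself (the real tests wrap TE ops such as cast_to_fp8 and fp8_gelu):

```python
# Minimal sketch, assuming a stand-in op in place of the TE op under test.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeBF16IO(nn.Module):
    """Hypothetical module: FP32 at the boundary, BF16 inside."""
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        x = inp.type(torch.bfloat16)        # cast in: FP32 -> BF16
        y = F.gelu(x, approximate="tanh")   # the op under test runs in BF16
        return y.type(torch.float32)        # cast out: BF16 -> FP32

out = FakeBF16IO()(torch.randn(256, 64))    # FP32 i/o, as ORT requires
```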
@@ -69,6 +69,7 @@ ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
 fp8_available, reason_for_no_fp8 = is_fp8_available()
 skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)

 def create_fp8_recipe():
     return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
@@ -152,7 +153,6 @@ def validate_result(
     allow_cnt_errors: int=0,
     input_names: list=["input"],
     output_names: list=["output"],
-    infer_ort=True
 ):
     """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
     representation using ONNX Runtime (ORT) and ensure they are close.
@@ -248,11 +248,10 @@ def validate_result(
     # Run ORT session and TE model.
     fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
     te_outputs = te_infer(model, inps, is_fp8)
-    if infer_ort:
-        ort_s = create_ort_session(fname, is_fp8)
-        input_feed = create_ort_input_dict(ort_s, inps)
-        onnx_outputs = ort_s.run(None, input_feed=input_feed)
-        compare_outputs(onnx_outputs, te_outputs)
+    ort_s = create_ort_session(fname, is_fp8)
+    input_feed = create_ort_input_dict(ort_s, inps)
+    onnx_outputs = ort_s.run(None, input_feed=input_feed)
+    compare_outputs(onnx_outputs, te_outputs)
     serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)
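For readers unfamiliar with the helpers above, here is a hedged sketch of what the comparison step amounts to. compare_outputs is defined elsewhere in this test file, so the signature and tolerance handling below are assumptions, not the actual implementation:

```python
# Sketch of the comparison performed by compare_outputs (assumed semantics).
import numpy as np

def compare_outputs_sketch(onnx_outputs, te_outputs, atol=1e-5, rtol=0.0,
                           allow_cnt_errors=0):
    for ort_out, te_out in zip(onnx_outputs, te_outputs):
        te_np = te_out.detach().cpu().numpy()
        mismatch = ~np.isclose(ort_out, te_np, rtol=rtol, atol=atol)
        # FP8 quantization can flip a few borderline elements, so up to
        # allow_cnt_errors mismatching values are tolerated.
        assert mismatch.sum() <= allow_cnt_errors, f"{mismatch.sum()} mismatches"
```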
@@ -290,20 +289,28 @@ def get_attn_mask_str(use_mask, attn_mask_type):
     return attn_mask_str


+"""
+Tests cases begin here.
+"""

 @skip_FP8
-@pytest.mark.parametrize("scale_factor, atol", [
-    (1, 1e-7),
-    (224, 1e-7)
+@pytest.mark.parametrize("scale_factor", [1, 224])
+@pytest.mark.parametrize(
+    "precision, atol", [
+    [torch.float32, 1e-7],
+    [torch.float16, 1e-7],
+    [torch.bfloat16, 5e-3]
 ])
-@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
     class TestFP8_QDQ(nn.Module):
-        def __init__(self):
+        def __init__(self, fake_bf16_io):
             super().__init__()
             self.fp8_tensor = 0
             self.meta = create_meta(scale_factor)
             self.highprec_type = as_te_type(precision)
             self.fp8_type = tex.DType.kFloat8E4M3
+            self.fake_bf16_io = fake_bf16_io

         def forward(self, inp):
             ret = cast_to_fp8(
@@ -318,34 +325,39 @@ def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtyp
                 self.fp8_tensor,
                 self.fp8_type,
                 self.highprec_type)
+            if self.fake_bf16_io:
+                ret = ret.type(torch.float32)
             return ret

     # Set dimensions (these are arbitrary).
     in_features = 64
     hidden_size = 256
-    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
+    fake_bf16_io = precision == torch.bfloat16
+    inp = torch.randn(hidden_size, in_features, device="cuda",
+        dtype=torch.float if fake_bf16_io else precision)
     high_prec_str = dtype2str(precision)
     fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
-    model = TestFP8_QDQ()
+    model = TestFP8_QDQ(fake_bf16_io)
     do_export(model, inp, fname)
     validate_result(fname, inp, model, atol=atol, is_fp8=True)


 @skip_FP8
 @pytest.mark.parametrize("scale_factor", [448])
 @pytest.mark.parametrize(
     "precision, atol", [
     [torch.float32, 1e-5],
-    [torch.float16, 1e-5]
+    [torch.float16, 1e-5],
+    [torch.bfloat16, 5e-3]
 ])
 def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
     class TestFP8_Gelu(nn.Module):
-        def __init__(self):
+        def __init__(self, fake_bf16_io):
             super().__init__()
             self.fp8_tensor = 0
             self.meta = create_meta(scale_factor)
             self.highprec_type = as_te_type(precision)
             self.fp8_type = tex.DType.kFloat8E4M3
+            self.fake_bf16_io = fake_bf16_io

         def forward(self, inp):
             ret = fp8_gelu(
@@ -359,15 +371,19 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
                 self.fp8_tensor,
                 self.fp8_type,
                 self.highprec_type)
+            if self.fake_bf16_io:
+                ret = ret.type(torch.float32)
             return ret

     # Set dimensions (these are arbitrary).
     in_features = 64
     hidden_size = 256
-    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
+    fake_bf16_io = precision == torch.bfloat16
+    inp = torch.randn(hidden_size, in_features, device="cuda",
+        dtype=torch.float if fake_bf16_io else precision)
     high_prec_str = dtype2str(precision)
     fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
-    model = TestFP8_Gelu()
+    model = TestFP8_Gelu(fake_bf16_io)
     do_export(model, inp, fname)
     validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2)
@@ -481,7 +497,8 @@ def test_export_gemm(
                 # test gelu
                 gelu=self.gelu,
                 gelu_input=self.gelu_input,
-                grad=False  # only True for backward pass
+                grad=False,  # only True for backward pass
+                accumulate=False,
             )
             return ret
@@ -503,8 +520,7 @@ def test_export_gemm(
         do_export(model, (inp, weight), fname, use_fp8)
         if precision == torch.bfloat16:
             return
-        infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
-        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True, infer_ort=infer_ort)
+        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True)
     else:
         model = Test_GEMM(precision, use_bias, use_gelu)
         do_export(model, (inp, weight), fname, use_fp8)
@@ -584,7 +600,7 @@ def test_export_layernorm(
     do_export(model, inp, fname, use_fp8=use_fp8)
     if precision not in (torch.bfloat16, ):
         validate_result(
-            fname, inp, model, atol=1e-4, is_fp8=use_fp8, allow_cnt_errors=3)
+            fname, inp, model, atol=1e-7, is_fp8=use_fp8, allow_cnt_errors=3)


 @skip_FP8
@@ -769,8 +785,8 @@ def test_export_layernorm_linear(
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
-    elif precision not in (torch.bfloat16,):
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
+    elif precision != torch.bfloat16:
+        validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)


 @pytest.mark.parametrize("scale_factor", [112])
@@ -826,7 +842,7 @@ def test_export_layernorm_mlp(
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
     else:
-        validate_result(fname, inp, model, atol=2e-2, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)


 @skip_FP8
 @pytest.mark.parametrize(
@@ -892,13 +908,10 @@ test_configs_attention_type = [
     (False, "self", True),
     (True, "self", False),
     (False, "self", False),
-    # disabled because query_bias (reqd for cross attention) is defined when fuse_qkv_params is False
-    # (True, "cross", True),
-    # (False, "cross", True),
+    (True, "cross", True),
+    (False, "cross", True),
     (True, "cross", False),
-    # disabled because TypeError: cannot assign 'transformer_engine.pytorch.module.Linear'
-    # as parameter 'query' (torch.nn.Parameter or None expected)
-    # (False, "cross", False),
+    (False, "cross", False),
 ]

 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -969,12 +982,14 @@ def test_export_multihead_attention(
         fuse_qkv_params=fuse_qkv_params,
     ).to(device='cuda')
     do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
-    infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
+    elif precision == torch.float32:
+        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
+            input_names=input_names, output_names=output_names)
     else:
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, infer_ort=infer_ort)
+        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
+            input_names=input_names, output_names=output_names, allow_cnt_errors=3)


 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -1031,11 +1046,10 @@ def test_export_transformer_layer(
         fuse_qkv_params=fuse_qkv_params,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
-    infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
    if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
     else:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names, infer_ort=infer_ort)
+        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names)


 @pytest.mark.parametrize("use_fp8", [True])
...
@@ -29,14 +29,12 @@ from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils
 import torch._C._onnx as _C_onnx
 import transformer_engine_extensions as tex

 # This file registers custom op symbolic ONNX functions and does not export any symbols.
 __all__ = []

 # Custom ops spec version
 VER = 1

 UNSPECIFIED_TYPE = -1
@@ -139,10 +137,10 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
 def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_gelu"""
     # pylint: disable=unused-argument
-    wrapped_gelu = lambda inputs: torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
+    wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
     # TE computes GELU using float32 precision so wrap the GELU subgraph with
     # conversion to/from float32.
-    gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=False)
+    gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)
     out = quantize(g, gelu, scale_inv, fp8_tensor)
     return out
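The fix above relies on compute_in_fp32, which is defined elsewhere in this file. Here is a hedged sketch of how such a wrapper can be built from ONNX Cast nodes; the signature and the default target dtype are assumptions for illustration:

```python
# Sketch only: wrap an ONNX subgraph with Cast-to-FP32 on entry and an
# optional Cast back on exit (here to FP16 for illustration).
def compute_in_fp32_sketch(g, inp, subgraph_fn, cast_outp,
                           outp_type=_C_onnx.TensorProtoDataType.FLOAT16):
    inp_fp32 = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    out = subgraph_fn(inp_fp32)          # e.g. wrapped_gelu from above
    if cast_outp:
        out = g.op("Cast", out, to_i=outp_type)
    return out
```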
@@ -179,27 +177,26 @@ def onnx_te_gemm(
     # pylint: disable=unused-argument
     is_fp16 = is_dtype_fp16(inputs)
     if input_type == int(tex.DType.kFloat8E4M3):
-        inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
+        inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
     if weight_type == int(tex.DType.kFloat8E4M3):
-        weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
-
-    output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
+        weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, out_type)

     empty_tensor_size = [0]
     bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
     pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
         == empty_tensor_size

+    if not bias_empty:
+        output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight)
+    else:
+        output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
     if not bias_empty:
-        if pre_gelu_out_empty:
-            if is_fp16:
-                output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
-            output = g.op('Add', output, bias)
-        else:
-            if is_fp16:
-                output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
-            output = g.op('Add', output, bias)
-            output = torch.onnx.symbolic_opset9.gelu(g, output)
+        if not pre_gelu_out_empty:
+            wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
+            # TE computes GELU using float32 precision so wrap the GELU subgraph with
+            # conversion to/from float32.
+            output = compute_in_fp32(g, output, wrapped_gelu, cast_outp=True)
     else:
         if is_fp16:
             output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
...
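In plain PyTorch terms, the refactored FP8 GEMM path now exports the equivalent of the following. This is a sketch under the assumption that transB is set (weight stored as [out_features, in_features]); the actual transpose handling depends on the trans_input/trans_weight flags, and this is not TE's code:

```python
import torch
import torch.nn.functional as F

def gemm_bias_gelu_reference(inp, weight, bias):
    # Bias is folded into the Gemm node itself (one op instead of Gemm + Add),
    # and GELU is computed in FP32 to match TE's internal GELU precision.
    out = inp @ weight.t() + bias
    return F.gelu(out.float(), approximate="tanh").to(out.dtype)

y = gemm_bias_gelu_reference(torch.randn(8, 64),
                             torch.randn(32, 64),
                             torch.randn(32))
```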