Unverified commit 1bc86400 authored by galagam, committed by GitHub

Fixes to test_onnx_export when saving input and output tensors (#173)



* Fixes to test_onnx_export when saving input and output tensors

- Allow saving i/o tensors when onnxruntime inference is skipped
- Support saving multiple outputs
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 14c2b719
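
Summary of the interface change shown in the diff below: validate_result now accepts explicit input_names/output_names lists and an infer_ort switch, so TE reference outputs can be serialized even when ONNX Runtime inference is skipped, and modules with multiple outputs are handled. A hedged sketch of the new call shape (argument values here are purely illustrative, not taken from any particular test):

```python
# Illustrative only: the keyword arguments introduced by this change.
validate_result(
    fname, inp, model,
    atol=1e-2,
    is_fp8=True,
    input_names=["input", "weight"],       # must match the names passed to do_export
    output_names=["output", "output_1"],   # multiple named outputs are now supported
    infer_ort=False,                       # skip ORT inference; TE outputs are still saved when I/O saving is enabled
)
```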
@@ -150,6 +150,9 @@ def validate_result(
     max_errors_printed: int=10,
     is_fp8: bool=False,
     allow_cnt_errors: int=0,
+    input_names: list=["input"],
+    output_names: list=["output"],
+    infer_ort=True
 ):
     """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
     representation using ONNX Runtime (ORT) and ensure they are close.
@@ -195,17 +198,23 @@ def validate_result(
         inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
         return inp_dict

-    def serialize_inputs_outputs(fname, input_feed, te_outputs):
+    def serialize_inputs_outputs(fname, inputs, inputs_names, te_outputs, output_names):
         if not SAVE_TEST_IO:
             return

-        input_data = [{k: v for k,v in input_feed.items()}]
+        named_inputs = zip(inputs_names, inputs)
+        input_data = [{k: to_numpy(v) for k, v in named_inputs if v is not None}]
         json_fname = fname[:-len(".onnx")] + "_inputs.json"
         save_json(input_data, json_fname, description="custom input data")

-        for i, outp in enumerate(te_outputs):
-            if outp is not None and "bf16" not in fname:
+        if "bf16" in fname:
+            return
+
         json_fname = fname[:-len(".onnx")] + "_output.json"
-                output_data = {"output": outp}
+        named_outputs = zip(output_names, te_outputs)
+        output_data = dict()
+        for out_name, outp in named_outputs:
+            if outp is not None:
+                assert out_name not in output_data
+                output_data[out_name] = outp
         custom_outputs = RunResults()
         custom_outputs.add([output_data], runner_name="custom_runner")
         custom_outputs.save(json_fname)
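
The rewritten serialize_inputs_outputs pairs tensors with caller-supplied names instead of reading names from the ORT feed dict, which is what allows saving when no ORT session exists and supports more than one output. A self-contained sketch of the same pattern, assuming torch and polygraphy are installed; tensor shapes and file names are hypothetical:

```python
# Standalone sketch of the name/tensor pairing used above; not part of this commit.
import torch
from polygraphy.comparator import RunResults
from polygraphy.json import save_json

def to_numpy(t):
    # Mirrors the test helper: move the tensor to host memory before serialization.
    return t.detach().cpu().numpy()

inputs = (torch.randn(2, 4), None)                 # None entries (e.g. an absent mask) are skipped
input_names = ["input", "mask"]
te_outputs = (torch.randn(2, 4), torch.randn(4))   # a module that returns two tensors
output_names = ["output", "output_1"]

# Inputs: a single-iteration list of {name: array} dicts.
input_data = [{k: to_numpy(v) for k, v in zip(input_names, inputs) if v is not None}]
save_json(input_data, "example_inputs.json", description="custom input data")

# Outputs: one entry per non-None output, so multi-output modules round-trip correctly.
output_data = {k: to_numpy(v) for k, v in zip(output_names, te_outputs) if v is not None}
custom_outputs = RunResults()
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save("example_output.json")
```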
@@ -237,12 +246,13 @@ def validate_result(
     # Run ORT session and TE model.
     fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
+    te_outputs = te_infer(model, inps, is_fp8)
+    if infer_ort:
         ort_s = create_ort_session(fname, is_fp8)
         input_feed = create_ort_input_dict(ort_s, inps)
         onnx_outputs = ort_s.run(None, input_feed=input_feed)
-    te_outputs = te_infer(model, inps, is_fp8)
-    serialize_inputs_outputs(fname, input_feed, te_outputs)
         compare_outputs(onnx_outputs, te_outputs)
+    serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)


 def create_meta(scale_factor: float, size: int=1):
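
For completeness, the saved artifacts can be read back with the same polygraphy utilities the test file already imports; a hedged illustration (the file names are the hypothetical ones from the sketch above, and the actual downstream consumer will vary):

```python
# Illustration only: loading the serialized I/O artifacts back for inspection.
from polygraphy.comparator import RunResults
from polygraphy.json import load_json

input_data = load_json("example_inputs.json")
print(f"{len(input_data)} input iteration(s), input names: {list(input_data[0].keys())}")

reference = RunResults.load("example_output.json")
for runner_name, iteration_results in reference:
    for iteration in iteration_results:
        print(runner_name, list(iteration.keys()))
```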
@@ -490,8 +500,10 @@ def test_export_gemm(
     if use_fp8:
         model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
         do_export(model, (inp, weight), fname, use_fp8)
-        if precision not in (torch.bfloat16, torch.float16):
-            validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True)
+        if precision == torch.bfloat16:
+            return
+        infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
+        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True, infer_ort=infer_ort)
     else:
         model = Test_GEMM(precision, use_bias, use_gelu)
         do_export(model, (inp, weight), fname, use_fp8)
@@ -624,7 +636,7 @@ def test_export_softmax(softmax_def, precision):
     inp = (input_tensor, mask)
     do_export(model, inp, fname, input_names=input_names)
     if precision != torch.bfloat16:
-        validate_result(fname, inp, model, atol=1e-3)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)


 @pytest.mark.parametrize("scale_factor", [1])
@@ -740,6 +752,7 @@ def test_export_layernorm_linear(
     bias_str = "_bias" if use_bias else ""
     high_prec_str = dtype2str(precision)
     fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"

     with te.fp8_autocast(enabled=use_fp8):
         model = te.LayerNormLinear(
             hidden_size,
@@ -863,7 +876,7 @@ def test_export_core_attention(
         fname,
         input_names=input_names,
         use_fp8=True)
-    validate_result(fname, inp, model, atol=1e-2)
+    validate_result(fname, inp, model, atol=1e-2, input_names=input_names)


 test_configs_multihead_attention = [
@@ -935,6 +948,7 @@ def test_export_multihead_attention(
     encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
     inp = (hidden_states, attention_mask, encoder_output)
     input_names = ["hidden_states", "attention_mask", "encoder_output"]
+    output_names=["output", "output_1"]
     fp8_str = "_fp8" if use_fp8 else ""
     dtype_str = dtype2str(precision)
@@ -953,11 +967,12 @@ def test_export_multihead_attention(
         attention_type=attention_type,
         fuse_qkv_params=fuse_qkv_params,
     ).to(device='cuda')
-    do_export(model, inp, fname, use_fp8, input_names=input_names)
+    do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
+    infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
     if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
+    else:
+        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, infer_ort=infer_ort)


 @pytest.mark.parametrize("use_fp8", [False, True])
@@ -1015,10 +1030,11 @@ def test_export_transformer_layer(
         fuse_qkv_params=fuse_qkv_params,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
+    infer_ort = precision != torch.float16  # temporarily skipping onnxrt inference due to input type mismatch bug
     if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
+    else:
+        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names, infer_ort=infer_ort)


 @pytest.mark.parametrize("use_fp8", [True])