Unverified Commit 1bc86400 authored by galagam, committed by GitHub

Fixes to test_onnx_export when saving input and output tensors (#173)



* Fixes to test_onnx_export when saving input and output tensors

- Allow saving i/o tensors when onnxruntime inference is skipped
- Support saving multiple outputs
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 14c2b719
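
The core of the change is to pair positional tensors with their ONNX graph names and to drop absent (`None`) entries before serializing. A minimal, self-contained sketch of that pattern (the helper and data below are illustrative, not part of the test):

```python
# Hypothetical helper mirroring the zip-based naming this PR introduces:
# pair each tensor with its graph name and skip entries that are None.
import numpy as np

def name_tensors(names, tensors):
    """Return {graph_name: array}, dropping tensors that are absent (None)."""
    return {name: np.asarray(t) for name, t in zip(names, tensors) if t is not None}

inputs = (np.ones((2, 4), dtype=np.float32), None)  # e.g. (hidden_states, mask)
print(name_tensors(["input", "attention_mask"], inputs))  # only "input" survives
```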
@@ -150,6 +150,9 @@ def validate_result(
     max_errors_printed: int=10,
     is_fp8: bool=False,
     allow_cnt_errors: int=0,
+    input_names: list=["input"],
+    output_names: list=["output"],
+    infer_ort=True
 ):
     """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
     representation using ONNX Runtime (ORT) and ensure they are close.
@@ -195,20 +198,26 @@ def validate_result(
         inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
         return inp_dict

-    def serialize_inputs_outputs(fname, input_feed, te_outputs):
+    def serialize_inputs_outputs(fname, inputs, inputs_names, te_outputs, output_names):
         if not SAVE_TEST_IO:
             return
-        input_data = [{k: v for k,v in input_feed.items()}]
+        named_inputs = zip(inputs_names, inputs)
+        input_data = [{k: to_numpy(v) for k, v in named_inputs if v is not None}]
         json_fname = fname[:-len(".onnx")] + "_inputs.json"
         save_json(input_data, json_fname, description="custom input data")

-        for i, outp in enumerate(te_outputs):
-            if outp is not None and "bf16" not in fname:
-                json_fname = fname[:-len(".onnx")] + "_output.json"
-                output_data = {"output": outp}
-                custom_outputs = RunResults()
-                custom_outputs.add([output_data], runner_name="custom_runner")
-                custom_outputs.save(json_fname)
+        if "bf16" in fname:
+            return
+        json_fname = fname[:-len(".onnx")] + "_output.json"
+        named_outputs = zip(output_names, te_outputs)
+        output_data = dict()
+        for out_name, outp in named_outputs:
+            if outp is not None:
+                assert out_name not in output_data
+                output_data[out_name] = outp
+        custom_outputs = RunResults()
+        custom_outputs.add([output_data], runner_name="custom_runner")
+        custom_outputs.save(json_fname)

     def compare_outputs(onnx_outputs, te_outputs):
         """ Compare ORT and TE outputs."""
@@ -237,12 +246,13 @@ def validate_result(
     # Run ORT session and TE model.
     fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
-    ort_s = create_ort_session(fname, is_fp8)
-    input_feed = create_ort_input_dict(ort_s, inps)
-    onnx_outputs = ort_s.run(None, input_feed=input_feed)
     te_outputs = te_infer(model, inps, is_fp8)
-    serialize_inputs_outputs(fname, input_feed, te_outputs)
-    compare_outputs(onnx_outputs, te_outputs)
+    if infer_ort:
+        ort_s = create_ort_session(fname, is_fp8)
+        input_feed = create_ort_input_dict(ort_s, inps)
+        onnx_outputs = ort_s.run(None, input_feed=input_feed)
+        compare_outputs(onnx_outputs, te_outputs)
+    serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)


 def create_meta(scale_factor: float, size: int=1):
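
The new `infer_ort` flag wraps the ONNX Runtime path so I/O tensors can still be saved when inference is skipped. As a standalone illustration, the gated path amounts to roughly this against the public onnxruntime API (model path, shape, and dtype are hypothetical):

```python
# Roughly what create_ort_session + run do, via the public onnxruntime API.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
feed = {sess.get_inputs()[0].name: np.random.rand(64, 64).astype(np.float32)}
onnx_outputs = sess.run(None, input_feed=feed)  # None -> fetch all graph outputs
print([o.shape for o in onnx_outputs])
```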
@@ -490,8 +500,10 @@ def test_export_gemm(
     if use_fp8:
         model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
         do_export(model, (inp, weight), fname, use_fp8)
-        if precision not in (torch.bfloat16, torch.float16):
-            validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True)
+        if precision == torch.bfloat16:
+            return
+        infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
+        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True, infer_ort=infer_ort)
     else:
         model = Test_GEMM(precision, use_bias, use_gelu)
         do_export(model, (inp, weight), fname, use_fp8)
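
The resulting control flow: bf16 exports return before any validation, fp16 validates against saved TE outputs only (ORT inference skipped), and all other precisions take the full path. The same gating as a pure function, for clarity (names are illustrative; the real test inlines this logic):

```python
import torch

def gemm_validation_mode(precision):
    """Return (run_validate, infer_ort) per the gating above (illustrative)."""
    if precision == torch.bfloat16:
        return False, False                       # export only, no validation
    return True, precision != torch.float16      # fp16: validate without ORT

print(gemm_validation_mode(torch.float32))   # (True, True)
print(gemm_validation_mode(torch.float16))   # (True, False)
print(gemm_validation_mode(torch.bfloat16))  # (False, False)
```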
@@ -624,7 +636,7 @@ def test_export_softmax(softmax_def, precision):
     inp = (input_tensor, mask)
     do_export(model, inp, fname, input_names=input_names)
     if precision != torch.bfloat16:
-        validate_result(fname, inp, model, atol=1e-3)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)


 @pytest.mark.parametrize("scale_factor", [1])
@@ -740,6 +752,7 @@ def test_export_layernorm_linear(
     bias_str = "_bias" if use_bias else ""
     high_prec_str = dtype2str(precision)
     fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
+
     with te.fp8_autocast(enabled=use_fp8):
         model = te.LayerNormLinear(
             hidden_size,
@@ -863,7 +876,7 @@ def test_export_core_attention(
         fname,
         input_names=input_names,
         use_fp8=True)
-    validate_result(fname, inp, model, atol=1e-2)
+    validate_result(fname, inp, model, atol=1e-2, input_names=input_names)


 test_configs_multihead_attention = [
@@ -935,6 +948,7 @@ def test_export_multihead_attention(
     encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
     inp = (hidden_states, attention_mask, encoder_output)
     input_names = ["hidden_states", "attention_mask", "encoder_output"]
+    output_names=["output", "output_1"]
     fp8_str = "_fp8" if use_fp8 else ""
     dtype_str = dtype2str(precision)
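
`output_names` matters because `torch.onnx.export` assigns names to graph outputs positionally; a module returning two tensors needs two names for the saved outputs to be addressable. A toy sketch (the module below is illustrative, not the TE attention block):

```python
import torch

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x * 2, x + 1  # two positional outputs -> two graph output names

torch.onnx.export(
    TwoOutputs(), (torch.randn(2, 3),), "two_outputs.onnx",
    input_names=["hidden_states"],
    output_names=["output", "output_1"],  # mirrors the names used in this test
)
```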
@@ -953,11 +967,12 @@ def test_export_multihead_attention(
         attention_type=attention_type,
         fuse_qkv_params=fuse_qkv_params,
     ).to(device='cuda')
-    do_export(model, inp, fname, use_fp8, input_names=input_names)
+    do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
+    infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
     if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
+    else:
+        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8, input_names=input_names, output_names=output_names, infer_ort=infer_ort)


 @pytest.mark.parametrize("use_fp8", [False, True])
@@ -1015,10 +1030,11 @@ def test_export_transformer_layer(
         fuse_qkv_params=fuse_qkv_params,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
+    infer_ort = precision != torch.float16 # temporarily skipping onnxrt inference due to input type mismatch bug
     if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
+    else:
+        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names, infer_ort=infer_ort)


 @pytest.mark.parametrize("use_fp8", [True])
......