"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "2af163d9c7d5c0276b7b2385034c1dd9181b187c"
Unverified commit e4a84a8d, authored by Neta Zmora, committed via GitHub
Browse files

Add an option to serialize test i/o to file (ONNX export tests) (#107)



Add an option to serialize test i/o to file

Small refactoring of the inferencing code.
Change the default directory where generated ONNX files are stored.
Use the temp directory to avoid clogging the file system.
Add an option to serialize test input/output tensors to a
Polygraphy RunResults object.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 549666ae
......@@ -16,6 +16,7 @@ using custom ORT operations.
import os
import tempfile
import pytest
import warnings
import numpy as np
......@@ -33,12 +34,19 @@ import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
# Directory where generated ONNX test models are stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ONNX_FILES_DIR = os.path.join(TESTS_DIR, "./gen_onnx_models")
# Global test configuration knobs.
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = False
if SAVE_TEST_IO:
    # Polygraphy is only needed when I/O serialization is enabled, so the
    # imports are kept behind the flag to avoid a hard dependency.
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored.
# Uses the OS temp directory to avoid clogging the source tree.
TEST_ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), "./gen_onnx_models")
# The directory where this file is stored.
# NOTE(review): TESTS_DIR is assigned twice in this section, and both
# ONNX_FILES_DIR and TEST_ARTIFACTS_DIR name a "gen_onnx_models" directory —
# this looks like merge/diff residue; consider consolidating to one of each.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
......@@ -46,6 +54,9 @@ TRILU_OPSET = 14
# ONNX opset version used for all exports in this test module; must be at
# least TRILU_OPSET so models using ONNX::Trilu can be exported.
OPSET = 15
assert OPSET >= TRILU_OPSET
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
# NOTE(review): ORT_CUSTOM_OPS_LIB is also assigned earlier in the file with
# the same value — likely diff residue; one of the two can be removed.
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
skip_FP8 = pytest.mark.skipif(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
reason="Device compute capability 9.x required for FP8 execution.",
......@@ -75,10 +86,11 @@ def do_export(
)
model.cuda().eval()
os.makedirs(ONNX_FILES_DIR, exist_ok=True)
fname = os.path.join(ONNX_FILES_DIR, fname)
os.makedirs(TEST_ARTIFACTS_DIR, exist_ok=True)
fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
torch.onnx.export(model,
inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,),
inps,
fname,
verbose=False,
opset_version=opset,
......@@ -170,13 +182,23 @@ def validate_result(
inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
return inp_dict
# Run ORT session and TE model.
fname = os.path.join(ONNX_FILES_DIR, fname)
ort_s = create_ort_session(fname, is_fp8)
onnx_outputs = ort_s.run(None, input_feed=create_ort_input_dict(ort_s, inps))
te_outputs = te_infer(model, inps, is_fp8)
# Compare ORT and TE outputs.
def serialize_inputs_outputs(fname, input_feed, te_outputs):
    """Optionally serialize test inputs and outputs to JSON files.

    No-op unless the module-level ``SAVE_TEST_IO`` flag is set. When enabled,
    the ORT input feed is written to ``<model>_inputs.json`` via Polygraphy's
    ``save_json``, and each non-None TE output is written as a Polygraphy
    ``RunResults`` object to ``<model>_output_<i>.json``.

    Args:
        fname: path of the exported ONNX model file (expected to end in ".onnx").
        input_feed: dict mapping ORT input names to input tensors/arrays.
        te_outputs: sequence of TE model outputs; entries may be None.
    """
    if not SAVE_TEST_IO:
        return

    base = fname[:-len(".onnx")]
    # Polygraphy expects a list of feed dicts; copy so the serialized data is
    # decoupled from the caller's dict.
    save_json([dict(input_feed)], base + "_inputs.json",
              description="custom input data")

    # bf16 outputs are skipped entirely (presumably not serializable or not
    # comparable downstream — TODO confirm with the consuming harness).
    if "bf16" in fname:
        return
    for i, outp in enumerate(te_outputs):
        if outp is None:
            continue
        # BUG FIX: the original wrote every output to the same
        # "<model>_output.json" file (the loop index was unused), so later
        # outputs clobbered earlier ones; include the index in the filename.
        custom_outputs = RunResults()
        custom_outputs.add([{"output": outp}], runner_name="custom_runner")
        custom_outputs.save(f"{base}_output_{i}.json")
def compare_outputs(onnx_outputs, te_outputs):
""" Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
......@@ -199,6 +221,15 @@ def validate_result(
if len(mismatched_ids) > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors")
# Run ORT session and TE model.
fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
te_outputs = te_infer(model, inps, is_fp8)
serialize_inputs_outputs(fname, input_feed, te_outputs)
compare_outputs(onnx_outputs, te_outputs)
def create_meta(scale_factor: float, size: int=1):
meta = tex.FP8TensorMeta()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment