Unverified commit 441fa968, authored by galagam, committed by GitHub

ONNX export test - minor fixes (#200)



* ONNX export - input names fix

* Fix discrepancies due to input names not being defined correctly or not being passed to export
* Refactor ORT input feed creation for simplicity
* Control whether to save test IO files via environment variable
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* ONNX export test: minor refactor
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: galagam <96368689+galagam@users.noreply.github.com>

---------
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Signed-off-by: galagam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 73c9f421
@@ -44,7 +44,8 @@ from transformer_engine.pytorch.fp8 import is_fp8_available
 # Global test configuration knobs.
 # Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
-SAVE_TEST_IO = False
+SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))
 if SAVE_TEST_IO:
     from polygraphy.json import save_json
     from polygraphy.comparator import RunResults
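The knob is evaluated once at import time, so it has to be set in the environment before the test session starts. A minimal sketch of the toggle's semantics (the pytest invocation in the comment is an assumption, not taken from the diff):

```python
import os

# Any non-zero integer string enables saving; unset or "0" disables it.
# Typically set when launching the suite, e.g. (path assumed):
#   NVTE_ONNX_EXPORT_SAVE_TEST_IO=1 pytest tests/test_onnx_export.py
os.environ["NVTE_ONNX_EXPORT_SAVE_TEST_IO"] = "1"
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))
assert SAVE_TEST_IO is True
```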
@@ -96,7 +97,12 @@ def do_export(
     model.cuda().eval()
     os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
     fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
+    inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
+    assert len(inps) == len(input_names)
+    inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
+    input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]
     with te.onnx_export(True):
         torch.onnx.export(
             model,
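The four added lines keep `input_names` aligned with the tensors that actually reach `torch.onnx.export`: when an optional input such as a mask is `None`, its name is dropped instead of silently shifting onto the wrong tensor. A minimal sketch of that filtering, with strings standing in for tensors:

```python
# Stand-ins for (query, key, value, attention_mask) with the mask absent.
inps = ("q", "k", "v", None)
input_names = ["query", "key", "value", "attention_mask"]
assert len(inps) == len(input_names)

# Same logic as the diff: drop the names whose corresponding input is None.
inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]
print(input_names)  # ['query', 'key', 'value']
```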
@@ -185,17 +191,11 @@ def validate_result(
     s = ort.InferenceSession(fname, **kwargs)
     return s
 
-def create_ort_input_dict(session, inps):
-    inp_dict = {}
-    if isinstance(inps, tuple) or isinstance(inps, list):
-        nonetype_inputs = 0
-        for idx, inp in enumerate(inps):
-            if inp is None:
-                nonetype_inputs += 1
-                continue
-            inp_dict[session.get_inputs()[idx - nonetype_inputs].name] = to_numpy(inp)
-    else:
-        inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
+def create_ort_input_dict(session, inputs):
+    inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
+    input_names = [x.name for x in session.get_inputs()]
+    inps = [to_numpy(x) for x in inputs if x is not None]
+    inp_dict = dict(zip(input_names, inps))
     return inp_dict
 
 def serialize_inputs_outputs(fname, inputs, inputs_names, te_outputs, output_names):
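The refactor replaces the manual `nonetype_inputs` bookkeeping with a single `zip` over the session's own input names, which already exclude the inputs the exporter dropped. A minimal sketch with a stand-in session object (`ort.InferenceSession` in the real test; `to_numpy()` is elided here):

```python
from types import SimpleNamespace

class FakeSession:
    """Stand-in exposing only the get_inputs() method the helper uses."""
    def get_inputs(self):
        return [SimpleNamespace(name="input"), SimpleNamespace(name="weight")]

def create_ort_input_dict(session, inputs):
    inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
    input_names = [x.name for x in session.get_inputs()]
    inps = [x for x in inputs if x is not None]  # the real helper also converts to numpy
    return dict(zip(input_names, inps))

# None entries vanish; the remaining tensors pair up with the graph inputs.
print(create_ort_input_dict(FakeSession(), ([1.0], None, [2.0])))
# -> {'input': [1.0], 'weight': [2.0]}
```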
@@ -515,16 +515,17 @@ def test_export_gemm(
     gelu_str = "_gelu" if use_gelu else ""
     high_prec_str = dtype2str(precision)
     fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
+    input_names = ['input', 'weight']
     if use_fp8:
         model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
-        do_export(model, (inp, weight), fname, use_fp8)
+        do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
         if precision == torch.bfloat16:
             return
-        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True)
+        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True, input_names=input_names)
     else:
         model = Test_GEMM(precision, use_bias, use_gelu)
-        do_export(model, (inp, weight), fname, use_fp8)
-        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2)
+        do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
+        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, input_names=input_names)
 
 @pytest.mark.parametrize("use_fp8", [False, True])
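Threading the same `input_names` through both `do_export` and `validate_result` is what lets the ORT input feed be keyed by the names baked into the ONNX graph. A hedged end-to-end sketch of that contract with a toy GEMM (assumes `torch`, `numpy`, and `onnxruntime` are installed; `TinyGemm` is illustrative, not one of the test's models):

```python
import numpy as np
import onnxruntime as ort
import torch

class TinyGemm(torch.nn.Module):
    def forward(self, inp, weight):
        return inp @ weight.t()

inp, weight = torch.randn(4, 8), torch.randn(16, 8)
# Export with explicit input names, as the updated tests now do.
torch.onnx.export(TinyGemm(), (inp, weight), "tiny_gemm.onnx",
                  input_names=["input", "weight"])

# Validate: feed ORT by those same names and compare against PyTorch.
sess = ort.InferenceSession("tiny_gemm.onnx", providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {"input": inp.numpy(), "weight": weight.numpy()})
np.testing.assert_allclose(out, (inp @ weight.t()).numpy(), rtol=1e-5, atol=1e-5)
```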
@@ -630,7 +631,7 @@ def test_export_softmax(softmax_def, precision):
     in_features = 64
     hidden_size = 256
     mask = None
-    input_names = ["input"]
+    input_names = ["input", "mask"]
     inp_shape = [hidden_size, in_features, in_features, in_features]
     if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax:
         inp_shape = [hidden_size, in_features, in_features]
@@ -640,7 +641,6 @@ def test_export_softmax(softmax_def, precision):
         # Generate a random mask with 50% probability for 0 or 1.
         probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
         mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-        input_names.append("mask")
         kernel_str = "ScaledMaskedSoftmax"
         model = Test_Softmax(softmax_def, mask_inp=True)
     elif softmax_def == softmax_defs.ScaledSoftmax:
@@ -866,13 +866,12 @@ def test_export_core_attention(
     query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
     key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
     value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
-    input_names = ["query", "key", "value"]
+    input_names = ["query", "key", "value", "attention_mask"]
     attention_mask = None
     if use_mask:
         # Generate a random mask with 50% probability for 0 or 1.
         probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
         attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-        input_names.append("attention_mask")
     inp = (query_layer, key_layer, value_layer, attention_mask)
 
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
@@ -881,6 +880,7 @@
     if attn_mask_type is None:
         attn_mask_type = 'causal'
+        input_names = ["query", "key", "value"]
         inp = (query_layer, key_layer, value_layer)
 
     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
@@ -958,6 +958,7 @@ def test_export_multihead_attention(
         attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
 
+    encoder_output = None
     if attention_type == "cross":
         encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
     inp = (hidden_states, attention_mask, encoder_output)
@@ -1021,13 +1022,12 @@ def test_export_transformer_layer(
     num_attention_heads = 4
 
     input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
-    input_names = ["input"]
+    input_names = ["input", "attention_mask"]
     attention_mask = None
     if use_mask and attn_mask_type != "causal":
         # Generate a random mask with 50% probability for 0 or 1.
         probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
         attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-        input_names.append("attention_mask")
     inp = (input_tensor, attention_mask)
 
     fp8_str = "_fp8" if use_fp8 else ""
@@ -1045,7 +1045,7 @@
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
-    do_export(model, inp, fname, use_fp8)
+    do_export(model, inp, fname, use_fp8, input_names=input_names)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
     else:
@@ -1167,10 +1167,11 @@ def test_export_gemm_layernorm(
     high_prec_str = dtype2str(precision)
     fp8_str = f"_fp8" if use_fp8 else ""
     fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
-    do_export(model, (inp, weight), fname, use_fp8=use_fp8)
+    input_names = ['input', 'weight']
+    do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
     if precision not in (torch.bfloat16, ):
         validate_result(
-            fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2)
+            fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, input_names=input_names)
 
 @pytest.mark.parametrize("enabled", [True, False])