Unverified Commit b921c0d1 authored by Neta Zmora, committed by GitHub

Fix model load exception when state resides on GPU (#140)



* Fix model load exception when state resides on GPU

- Whenever converting a torch.Tensor to numpy, the tensor storage must
first be migrated to the host CPU (see the sketch below, just ahead of the diff).

- Add a warning not to do constant folding when exporting to ONNX.
This is due to a torch.onnx export bug.

- Refactor compare_outputs
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Onnx export: Improve remark text
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ec1030b5
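For context, a minimal sketch (not part of the commit) of the failure mode the fix addresses: calling Tensor.numpy() on CUDA-resident storage raises a TypeError, so the tensor must first be detached and copied to host memory.

    import torch

    # Place a tensor on the GPU when one is available; the failure below
    # only occurs for CUDA-resident storage.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.ones(4, device=device)

    # On a CUDA tensor, t.numpy() raises:
    #   TypeError: can't convert cuda:0 device type tensor to numpy.
    #   Use Tensor.cpu() to copy the tensor to host memory first.
    arr = t.detach().cpu().numpy()  # detach, copy to host, then convert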
@@ -12,6 +12,12 @@ validate the output against TE's output.
 Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
 using custom ORT operations.
 
+To run many repetitive tests use pytest-loop:
+    $ python3 -m pip install pytest-loop
+    $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm
+
+For reproducibility use: torch.manual_seed(0)
+
 """
@@ -89,15 +95,19 @@ def do_export(
     fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
     inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
     with te.onnx_export(True):
-        torch.onnx.export(model,
-                          inps,
-                          fname,
-                          verbose=False,
-                          opset_version=opset,
-                          input_names=input_names,
-                          output_names=output_names,
-                          do_constant_folding=False,
-                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
+        torch.onnx.export(
+            model,
+            inps,
+            fname,
+            verbose=True,
+            opset_version=opset,
+            input_names=input_names,
+            output_names=output_names,
+            # Do not constant-fold because torch.onnx incorrectly folds
+            # layer_norm(data, scale=add(gamma,1)) to layer_norm(data, scale=gamma)
+            # when we use LN with zero-centered gamma.
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
 
 
 def to_numpy(tensor):
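The comment added in this hunk refers to Transformer Engine's zero-centered-gamma LayerNorm, where the effective scale is computed at runtime as add(gamma, 1). A hedged sketch of that pattern (function name hypothetical, not from the diff):

    import torch
    import torch.nn.functional as F

    def layernorm_zero_centered_gamma(x, gamma, beta, eps=1e-5):
        # gamma is stored zero-centered, so the effective scale is gamma + 1.
        # Exporting with do_constant_folding=True folds this add away, and the
        # exported graph wrongly uses gamma itself as the scale.
        return F.layer_norm(x, x.shape[-1:], weight=gamma + 1, bias=beta, eps=eps)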
@@ -201,26 +211,27 @@ def validate_result(
     def compare_outputs(onnx_outputs, te_outputs):
         """ Compare ORT and TE outputs."""
         assert len(onnx_outputs) == len(te_outputs)
+        # Compare ORT and PyTorch outputs.
         for onnx_output, te_output in zip(onnx_outputs, te_outputs):
-            # Compare ORT and PyTorch outputs.
             # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
             ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
             mismatches = ac.nonzero()
             mismatched_ids = [loc for loc in zip(*mismatches)]
             if mismatched_ids:
                 # Log some information in case of error.
                 print("*" * 100)
-                print(onnx_output.shape)
-                nb_vals = min(len(mismatched_ids), max_errors_printed)
-                print(f"Detected {len(mismatched_ids)} diverging values.\nShowing first {nb_vals} errors (ONNX -- TE):")
-                abs_err = abs(onnx_output - te_output)
+                nb_errors = len(mismatched_ids)
+                nb_vals = min(nb_errors, max_errors_printed)
+                print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
+                print(f"Showing first {nb_vals} errors (ONNX -- TE):")
+                abs_err = np.abs(onnx_output - te_output)
+                errors = abs_err[mismatches]
                 for loc in mismatched_ids[:nb_vals]:
                     ref = te_output[loc]
                     print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
-                if len(mismatched_ids) > allow_cnt_errors:
-                    raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors")
+                print(f"Max error: {np.max(errors)}")
+                if nb_errors > allow_cnt_errors:
+                    raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
 
     # Run ORT session and TE model.
     fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
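The mismatch-reporting idiom in compare_outputs can be exercised in isolation; a small sketch with made-up inputs:

    import numpy as np

    onnx_out = np.array([1.0, 2.0, 3.0])
    te_out = np.array([1.0, 2.5, 3.0])

    # ~np.isclose marks positions where abs(a - b) > atol + rtol * abs(b).
    mismatch_mask = ~np.isclose(onnx_out, te_out, atol=1e-3, rtol=1e-3)
    mismatched_ids = list(zip(*mismatch_mask.nonzero()))  # [(1,)]
    print(f"Detected {len(mismatched_ids)} diverging values")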
@@ -318,7 +318,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             return
 
         if isinstance(state, torch.Tensor):
-            state = pickle.loads(state.detach().numpy().tobytes())
+            state = pickle.loads(state.detach().cpu().numpy().tobytes())
 
         if state is None:
             return
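For context, a hedged sketch (helper names hypothetical, not Transformer Engine's API) of the round trip that makes the added .cpu() call necessary: extra state is pickled into a byte tensor at save time, and a checkpoint loaded onto a GPU hands the load path a CUDA-resident tensor.

    import pickle

    import torch

    def extra_state_to_tensor(obj):
        # Save path: pickle the object and wrap the bytes in a uint8 tensor.
        return torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)

    def tensor_to_extra_state(state):
        # Load path: the tensor may reside on the GPU after checkpoint load,
        # so copy it to host memory before converting to numpy.
        return pickle.loads(state.detach().cpu().numpy().tobytes())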