Unverified Commit c523a869 authored by Rak Alexey's avatar Rak Alexey Committed by GitHub
Browse files

fix marianMT convertion to onnx (#19287)



* fix marianMT convertion to onnx

* Update src/transformers/onnx/convert.py
Co-authored-by: default avatarlewtun <lewis.c.tunstall@gmail.com>

* Update src/transformers/onnx/convert.py
Co-authored-by: default avatarlewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: default avatarlewtun <lewis.c.tunstall@gmail.com>
parent 34107057
......@@ -392,3 +392,7 @@ class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
@property
def atol_for_validation(self) -> float:
return 1e-4
......@@ -450,10 +450,12 @@ def validate_model_outputs(
# Values
if not np.allclose(ref_value, ort_value, atol=atol):
bad_indices = np.logical_not(np.isclose(ref_value, ort_value, atol=atol))
logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
raise ValueError(
"Outputs values doesn't match between reference model and ONNX exported model: "
f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))} for "
f"{ref_value[bad_indices]} vs {ort_value[bad_indices]}"
)
else:
logger.info(f"\t\t-[✓] all values close (atol: {atol})")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment