Unverified Commit 66e86567 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

CLI: Print all different tensors on exception (#17612)

parent e9d51387
......@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=convert_command_factory)
@staticmethod
def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
difference and its source.
Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
differences.
"""
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
......@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f" {tf_out_attrs})"
)
# 2. For each output attribute, ALL values must be the same
def _compate_pt_tf_models(pt_out, tf_out, attr_name=""):
max_difference = 0
max_difference_source = ""
# 2. For each output attribute, computes the difference
def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursivelly, keeping the name of the attribute.
if isinstance(pt_out, (torch.Tensor)):
difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
if difference > max_difference:
max_difference = difference
max_difference_source = attr_name
if isinstance(pt_out, torch.Tensor):
tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
differences[attr_name] = tensor_difference
else:
root_name = attr_name
for i, pt_item in enumerate(pt_out):
......@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
else:
branch_name = root_name + f"[{i}]"
tf_item = tf_out[i]
difference, difference_source = _compate_pt_tf_models(pt_item, tf_item, branch_name)
if difference > max_difference:
max_difference = difference
max_difference_source = difference_source
differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)
return max_difference, max_difference_source
return differences
return _compate_pt_tf_models(pt_outputs, tf_outputs)
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
......@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
# Confirms that cross loading PT weights into TF worked.
crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
if crossload_diff >= MAX_ERROR:
crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
max_crossload_diff = max(crossload_differences.values())
if max_crossload_diff > MAX_ERROR:
raise ValueError(
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
f" {crossload_diff:.3e}, observed in {diff_source})"
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
)
)
# Save the weights in a TF format (if needed) and confirms that the results are still good
......@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_from_pt_model.save_weights(tf_weights_path)
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
tf_model = tf_class.from_pretrained(self._local_dir)
converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
if converted_diff >= MAX_ERROR:
conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
max_conversion_diff = max(conversion_differences.values())
if max_conversion_diff > MAX_ERROR:
raise ValueError(
"The converted TF model has different outputs, something went wrong! (max difference ="
f" {converted_diff:.3e}, observed in {diff_source})"
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f" tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
)
)
if not self._no_pr:
......@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
f" Max converted output difference={converted_diff:.3e}."
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment