Unverified commit 66e86567, authored by Joao Gante and committed by GitHub
Browse files

CLI: Print all different tensors on exception (#17612)

parent e9d51387
...@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=convert_command_factory) train_parser.set_defaults(func=convert_command_factory)
@staticmethod @staticmethod
def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input): def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
""" """
Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
difference and its source. differences.
""" """
pt_outputs = pt_model(**pt_input, output_hidden_states=True) pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True) tf_outputs = tf_model(**tf_input, output_hidden_states=True)
...@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f" {tf_out_attrs})" f" {tf_out_attrs})"
) )
# 2. For each output attribute, ALL values must be the same # 2. For each output attribute, computes the difference
def _compate_pt_tf_models(pt_out, tf_out, attr_name=""): def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
max_difference = 0
max_difference_source = ""
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursively, keeping the name of the attribute. # recursively, keeping the name of the attribute.
if isinstance(pt_out, (torch.Tensor)): if isinstance(pt_out, torch.Tensor):
difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy())) tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
if difference > max_difference: differences[attr_name] = tensor_difference
max_difference = difference
max_difference_source = attr_name
else: else:
root_name = attr_name root_name = attr_name
for i, pt_item in enumerate(pt_out): for i, pt_item in enumerate(pt_out):
...@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
else: else:
branch_name = root_name + f"[{i}]" branch_name = root_name + f"[{i}]"
tf_item = tf_out[i] tf_item = tf_out[i]
difference, difference_source = _compate_pt_tf_models(pt_item, tf_item, branch_name) differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)
if difference > max_difference:
max_difference = difference
max_difference_source = difference_source
return max_difference, max_difference_source return differences
return _compate_pt_tf_models(pt_outputs, tf_outputs) return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args): def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._logger = logging.get_logger("transformers-cli/pt_to_tf")
...@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)}) tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
# Confirms that cross loading PT weights into TF worked. # Confirms that cross loading PT weights into TF worked.
crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input) crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
if crossload_diff >= MAX_ERROR: max_crossload_diff = max(crossload_differences.values())
if max_crossload_diff > MAX_ERROR:
raise ValueError( raise ValueError(
"The cross-loaded TF model has different outputs, something went wrong! (max difference =" "The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f" {crossload_diff:.3e}, observed in {diff_source})" f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
)
) )
# Save the weights in a TF format (if needed) and confirms that the results are still good # Save the weights in a TF format (if needed) and confirms that the results are still good
...@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_from_pt_model.save_weights(tf_weights_path) tf_from_pt_model.save_weights(tf_weights_path)
del tf_from_pt_model # will no longer be used, and may have a large memory footprint del tf_from_pt_model # will no longer be used, and may have a large memory footprint
tf_model = tf_class.from_pretrained(self._local_dir) tf_model = tf_class.from_pretrained(self._local_dir)
converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input) conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
if converted_diff >= MAX_ERROR: max_conversion_diff = max(conversion_differences.values())
if max_conversion_diff > MAX_ERROR:
raise ValueError( raise ValueError(
"The converted TF model has different outputs, something went wrong! (max difference =" "The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f" {converted_diff:.3e}, observed in {diff_source})" f" tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
)
) )
if not self._no_pr: if not self._no_pr:
...@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
create_pr=True, create_pr=True,
pr_commit_summary="Add TF weights", pr_commit_summary="Add TF weights",
pr_commit_description=( pr_commit_description=(
f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};" "Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
f" Max converted output difference={converted_diff:.3e}." " hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
), ),
) )
self._logger.warn(f"PR open in {hub_pr_url}") self._logger.warn(f"PR open in {hub_pr_url}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment