"examples/vscode:/vscode.git/clone" did not exist on "64e6098094d063687f90d3bf49bdc7571551c344"
Unverified Commit 7edf8bfa authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Improve forward signature test (#27729)



* First draft

* Extend test_forward_signature

* Update tests/test_modeling_common.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Revert suggestion

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent bcd0a91a
......@@ -542,6 +542,12 @@ class ModelTesterMixin:
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
else:
expected_arg_names = [model.main_input_name]
self.assertListEqual(arg_names[:1], expected_arg_names)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment