"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "571fa585b6b2e7d377c8cc6e3a15a07c7af1e368"
Unverified Commit 7edf8bfa authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Improve forward signature test (#27729)



* First draft

* Extend test_forward_signature

* Update tests/test_modeling_common.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Revert suggestion

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent bcd0a91a
...@@ -542,6 +542,12 @@ class ModelTesterMixin: ...@@ -542,6 +542,12 @@ class ModelTesterMixin:
else ["encoder_outputs"] else ["encoder_outputs"]
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
self.assertListEqual(arg_names, expected_arg_names)
else: else:
expected_arg_names = [model.main_input_name] expected_arg_names = [model.main_input_name]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment