Unverified commit 3951b9f3, authored by Sylvain Gugger and committed by GitHub
Browse files

Add utility to find model labels (#16526)



* Add utility to find model labels

* Use it in the Trainer

* Update src/transformers/utils/generic.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Quality
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent ec4da72f
...@@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab ...@@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab
from .dependency_versions_check import dep_version_check from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import ( from .trainer_callback import (
...@@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments ...@@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import ( from .utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
find_labels,
get_full_repo_name, get_full_repo_name,
is_apex_available, is_apex_available,
is_datasets_available, is_datasets_available,
...@@ -495,11 +495,7 @@ class Trainer: ...@@ -495,11 +495,7 @@ class Trainer:
self.current_flos = 0 self.current_flos = 0
self.hp_search_backend = None self.hp_search_backend = None
self.use_tune_checkpoints = False self.use_tune_checkpoints = False
default_label_names = ( default_label_names = find_labels(self.model.__class__)
["start_positions", "end_positions"]
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
else ["labels"]
)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
......
...@@ -37,6 +37,7 @@ from .generic import ( ...@@ -37,6 +37,7 @@ from .generic import (
PaddingStrategy, PaddingStrategy,
TensorType, TensorType,
cached_property, cached_property,
find_labels,
is_tensor, is_tensor,
to_numpy, to_numpy,
to_py_obj, to_py_obj,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
Generic utilities Generic utilities
""" """
import inspect
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from contextlib import ExitStack from contextlib import ExitStack
from dataclasses import fields from dataclasses import fields
...@@ -289,3 +290,23 @@ class ContextManagers: ...@@ -289,3 +290,23 @@ class ContextManagers:
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs) self.stack.__exit__(*args, **kwargs)
def find_labels(model_class):
    """
    Find the labels used by a given model.

    The label names are read from the signature of the model's entry point. By naming convention, classes whose
    name starts with `TF` are inspected through `call`, classes whose name starts with `Flax` through `__call__`,
    and every other class through `forward`. The prefix is only a hint: a class that merely happens to be named
    `TF...` or `Flax...` but exposes `forward` is inspected through `forward` instead of failing or silently
    reporting no labels.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `List[str]`: The names, in signature order, of every parameter containing `"label"`. For classes with
        `"QuestionAnswering"` in their name, `start_positions` and `end_positions` are included as well.

    Raises:
        `AttributeError`: If the class exposes no entry point that can be inspected.
    """
    model_name = model_class.__name__
    has_forward = hasattr(model_class, "forward")

    if model_name.startswith("TF") and hasattr(model_class, "call"):
        entry_point = model_class.call
    elif model_name.startswith("Flax") and not has_forward:
        # `__call__` always resolves on a class (through the metaclass), so it cannot be probed with `hasattr`;
        # the presence of `forward` is what tells a misnamed class apart here.
        entry_point = model_class.__call__
    else:
        entry_point = model_class.forward

    # Question-answering heads take their targets as span boundaries rather than as `*label*` arguments.
    extra_label_names = ("start_positions", "end_positions") if "QuestionAnswering" in model_name else ()

    return [
        name for name in inspect.signature(entry_point).parameters if "label" in name or name in extra_label_names
    ]
...@@ -35,10 +35,14 @@ from transformers.utils import ( ...@@ -35,10 +35,14 @@ from transformers.utils import (
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
filename_to_url, filename_to_url,
find_labels,
get_file_from_repo, get_file_from_repo,
get_from_cache, get_from_cache,
has_file, has_file,
hf_bucket_url, hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
) )
...@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
class ContextManagerTests(unittest.TestCase): class GenericUtilTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO) @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout): def test_context_managers_no_context(self, mock_stdout):
with ContextManagers([]): with ContextManagers([]):
print("Transformers are awesome!") print("Transformers are awesome!")
# The print statement adds a new line at the end of the output # The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n") self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO) @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout): def test_context_managers_one_context(self, mock_stdout):
with ContextManagers([context_en()]): with ContextManagers([context_en()]):
print("Transformers are awesome!") print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye # The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n") self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO) @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout): def test_context_managers_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]): with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!") print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye # The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n") self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
def test_find_labels(self):
    # Each framework is only exercised when it is installed, so the test adapts to the available backends.
    if is_torch_available():
        from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification

        torch_expectations = (
            (BertForSequenceClassification, ["labels"]),
            (BertForPreTraining, ["labels", "next_sentence_label"]),
            (BertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in torch_expectations:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_tf_available():
        from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification

        tf_expectations = (
            (TFBertForSequenceClassification, ["labels"]),
            (TFBertForPreTraining, ["labels", "next_sentence_label"]),
            (TFBertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in tf_expectations:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_flax_available():
        # Flax models don't have labels
        from transformers import (
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
            FlaxBertForSequenceClassification,
        )

        flax_model_classes = (
            FlaxBertForSequenceClassification,
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
        )
        for model_class in flax_model_classes:
            self.assertEqual(find_labels(model_class), [])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment