Unverified commit 3951b9f3, authored by Sylvain Gugger and committed by GitHub
Browse files

Add utility to find model labels (#16526)



* Add utility to find model labels

* Use it in the Trainer

* Update src/transformers/utils/generic.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Quality
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent ec4da72f
......@@ -67,7 +67,6 @@ from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enab
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
......@@ -124,6 +123,7 @@ from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
CONFIG_NAME,
WEIGHTS_NAME,
find_labels,
get_full_repo_name,
is_apex_available,
is_datasets_available,
......@@ -495,11 +495,7 @@ class Trainer:
self.current_flos = 0
self.hp_search_backend = None
self.use_tune_checkpoints = False
default_label_names = (
["start_positions", "end_positions"]
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
else ["labels"]
)
default_label_names = find_labels(self.model.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
......
......@@ -37,6 +37,7 @@ from .generic import (
PaddingStrategy,
TensorType,
cached_property,
find_labels,
is_tensor,
to_numpy,
to_py_obj,
......
......@@ -15,6 +15,7 @@
Generic utilities
"""
import inspect
from collections import OrderedDict, UserDict
from contextlib import ExitStack
from dataclasses import fields
......@@ -289,3 +290,23 @@ class ContextManagers:
def __exit__(self, *args, **kwargs):
    """
    Leave the managed contexts by delegating to `self.stack.__exit__`.

    The result of the delegated call is returned so that an exception suppressed by one of the wrapped context
    managers stays suppressed; returning `None` here would make the `with` statement re-raise it.
    NOTE(review): `self.stack` is presumably a `contextlib.ExitStack` created elsewhere in the class -- confirm.

    Args:
        *args: The exception type, value and traceback handed over by the `with` statement.
        **kwargs: Forwarded unchanged to `self.stack.__exit__`.

    Returns:
        Whatever `self.stack.__exit__` returns (a true value means the exception was suppressed).
    """
    return self.stack.__exit__(*args, **kwargs)
def find_labels(model_class):
    """
    Find the labels used by a given model.

    The label arguments are read from the signature of the method that runs the model. The class name picks the
    preferred method: `call` when the name starts with `TF`, `__call__` when it starts with `Flax`, and `forward`
    otherwise. When the preferred method does not exist on the class (e.g. a class that only defines `call` but
    whose name does not start with `TF`), the remaining entry points are tried instead of failing with an
    `AttributeError`.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `List[str]`: The parameter names containing `"label"`, in signature order. When `"QuestionAnswering"` is
        part of the class name, `start_positions` and `end_positions` are returned as well.

    Raises:
        `AttributeError`: If the class exposes none of `forward`, `call` or `__call__`.
    """
    model_name = model_class.__name__
    # The first entry is the method the naming convention points to; the others are fallbacks for classes whose
    # name does not follow the `TF`/`Flax` prefix convention.
    if model_name.startswith("TF"):
        entry_points = ("call", "forward", "__call__")
    elif model_name.startswith("Flax"):
        entry_points = ("__call__", "forward", "call")
    else:
        entry_points = ("forward", "call", "__call__")

    method = None
    for entry_point in entry_points:
        method = getattr(model_class, entry_point, None)
        if method is not None:
            break
    if method is None:
        raise AttributeError(f"{model_name} has none of the methods {entry_points} to read labels from.")

    signature = inspect.signature(method)

    # Question answering models take their targets as start/end positions rather than `labels`.
    extra_names = ("start_positions", "end_positions") if "QuestionAnswering" in model_name else ()
    return [p for p in signature.parameters if "label" in p or p in extra_names]
......@@ -35,10 +35,14 @@ from transformers.utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels,
get_file_from_repo,
get_from_cache,
has_file,
hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
)
......@@ -158,24 +162,51 @@ class GetFromCacheTests(unittest.TestCase):
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
class ContextManagerTests(unittest.TestCase):
class GenericUtilTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout):
def test_context_managers_no_context(self, mock_stdout):
    """With an empty list of context managers, only the printed message is captured."""
    with ContextManagers([]):
        print("Transformers are awesome!")

    captured = mock_stdout.getvalue()
    # `print` appends a newline after the message
    self.assertEqual(captured, "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout):
def test_context_managers_one_context(self, mock_stdout):
    """With the English context, the message is framed by its welcome and goodbye lines."""
    managers = [context_en()]
    with ContextManagers(managers):
        print("Transformers are awesome!")

    captured = mock_stdout.getvalue()
    # English welcome first, then the message, then the English goodbye
    self.assertEqual(captured, "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout):
def test_context_managers_two_context(self, mock_stdout):
    """With the French then English contexts, the message is framed by both, French outermost."""
    managers = [context_fr(), context_en()]
    with ContextManagers(managers):
        print("Transformers are awesome!")

    captured = mock_stdout.getvalue()
    # French welcome/goodbye wrap the English welcome/goodbye, which wrap the message
    self.assertEqual(captured, "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
def test_find_labels(self):
    """Check the label names reported by `find_labels` for Bert models of every installed framework."""
    if is_torch_available():
        from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification

        torch_cases = (
            (BertForSequenceClassification, ["labels"]),
            (BertForPreTraining, ["labels", "next_sentence_label"]),
            (BertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in torch_cases:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_tf_available():
        from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification

        tf_cases = (
            (TFBertForSequenceClassification, ["labels"]),
            (TFBertForPreTraining, ["labels", "next_sentence_label"]),
            (TFBertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in tf_cases:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_flax_available():
        # Flax models don't have labels
        from transformers import (
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
            FlaxBertForSequenceClassification,
        )

        flax_classes = (
            FlaxBertForSequenceClassification,
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
        )
        for model_class in flax_classes:
            self.assertEqual(find_labels(model_class), [])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment