Unverified Commit 19a8a303 authored by Matt's avatar Matt Committed by GitHub
Browse files

Add magic method to our TF models to convert datasets with column inference (#17160)



* Add method to call to_tf_dataset() with column inference

* Add test for dataset creation

* Add a default arg for data collator

* Fix test

* Fix call with non-dev version of datasets

* Test correct column removal too

* make fixup

* More tests to make sure we remove unwanted columns

* Fix test to avoid predicting on unbuilt models

* Fix test to avoid predicting on unbuilt models

* Fix test to remove unwanted head mask columns from inputs

* Stop pushing your debug breakpoints to the main repo of the $2bn company you work for

* Skip the test in convnext because no grouped conv support

* Drop bools from the dataset dict

* Make style

* Skip the training test for models whose input dicts don't give us labels

* Skip transformerXL in the test because it doesn't return a simple loss

* Skip TFTapas because of some odd NaN losses

* make style

* make fixup

* Add docstring

* fixup

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove breakpoint from tests

* Fix assert, add requires_backends

* Protect tokenizer import with if TYPE_CHECKING

* make fixup

* Add noqa, more fixup

* More rearranging for ~* aesthetics *~

* Adding defaults for shuffle and batch_size to match to_tf_dataset()

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d28b7aa8
...@@ -22,7 +22,7 @@ import pickle ...@@ -22,7 +22,7 @@ import pickle
import re import re
import warnings import warnings
from collections.abc import Mapping from collections.abc import Mapping
from typing import Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np import numpy as np
...@@ -35,6 +35,7 @@ from tensorflow.python.keras.saving import hdf5_format ...@@ -35,6 +35,7 @@ from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files from huggingface_hub import Repository, list_repo_files
from requests import HTTPError from requests import HTTPError
from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
...@@ -58,9 +59,14 @@ from .utils import ( ...@@ -58,9 +59,14 @@ from .utils import (
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
logging, logging,
requires_backends,
) )
if TYPE_CHECKING:
from . import PreTrainedTokenizerBase
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
tf_logger = tf.get_logger() tf_logger = tf.get_logger()
...@@ -892,6 +898,94 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -892,6 +898,94 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# set it directly, but the user can pass it to fit(). # set it directly, but the user can pass it to fit().
return {"epoch": extra_data["epoch"]} return {"epoch": extra_data["epoch"]}
def prepare_tf_dataset(
self,
dataset: "datasets.Dataset", # noqa:F821
batch_size: int = 8,
shuffle: bool = True,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
collate_fn: Optional[Callable] = None,
collate_fn_args: Optional[Dict[str, Any]] = None,
drop_remainder: Optional[bool] = None,
prefetch: bool = True,
):
"""
Wraps a HuggingFace `datasets.Dataset` as a `tf.data.Dataset` with collation and batching. This method is
designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
further modification. The method will drop columns from the dataset if they don't match input names for the
model. If you want to specify the column names to return rather than using the names that match this model, we
recommend using `Dataset.to_tf_dataset()` instead.
Args:
dataset (`Any`):
A `datasets.Dataset` to be wrapped as a `tf.data.Dataset`.
batch_size (`int`, defaults to 8):
The size of batches to return.
shuffle (`bool`, defaults to `True`):
Whether to return samples from the dataset in random order. Usually `True` for training datasets and
`False` for validation/test datasets.
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
`collate_fn` is passed instead.
collate_fn (`Callable`, *optional*):
A function that collates samples from the dataset into a single batch. Defaults to
`DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
passed.
collate_fn_args (`Dict[str, Any]`, *optional*):
A dict of arguments to pass to the `collate_fn` alongside the list of samples.
drop_remainder (`bool`, *optional*):
Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
to the same setting as `shuffle`.
prefetch (`bool`, defaults to `True`):
Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
performance, but can be disabled in edge cases.
Returns:
`Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
"""
requires_backends(self, ["datasets"])
import datasets
if collate_fn is None:
if tokenizer is None:
collate_fn = DefaultDataCollator(return_tensors="tf")
else:
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
if collate_fn_args is None:
collate_fn_args = dict()
if not isinstance(dataset, datasets.Dataset):
raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_labels = find_labels(self.__class__)
unwanted_columns = [
feature
for feature in dataset.features
if feature not in model_inputs and feature not in ("label_ids", "label")
]
dataset = dataset.remove_columns(unwanted_columns)
output_signature, _ = dataset._get_output_signature(
dataset,
batch_size=None,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
)
output_columns = list(output_signature.keys())
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
label_cols = [col for col in output_columns if col in model_labels]
tf_dataset = dataset.to_tf_dataset(
columns=feature_cols,
label_cols=label_cols,
batch_size=batch_size,
shuffle=shuffle,
drop_remainder=drop_remainder,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
prefetch=prefetch,
)
return tf_dataset
def compile( def compile(
self, self,
optimizer="rmsprop", optimizer="rmsprop",
......
...@@ -174,6 +174,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -174,6 +174,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
def test_attention_outputs(self): def test_attention_outputs(self):
pass pass
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_dataset_conversion(self):
super().test_dataset_conversion()
def test_hidden_states_output(self): def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class): def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config) model = model_class(config)
......
...@@ -498,6 +498,10 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -498,6 +498,10 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
@unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
def test_dataset_conversion(self):
pass
def prepare_tapas_single_inputs_for_inference(): def prepare_tapas_single_inputs_for_inference():
# Here we prepare a single table-question pair to test TAPAS inference on: # Here we prepare a single table-question pair to test TAPAS inference on:
......
...@@ -216,6 +216,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -216,6 +216,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFTransfoXLModel.from_pretrained(model_name) model = TFTransfoXLModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
def test_dataset_conversion(self):
pass
@require_tf @require_tf
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase): class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
......
...@@ -25,6 +25,8 @@ import unittest.mock as mock ...@@ -25,6 +25,8 @@ import unittest.mock as mock
from importlib import import_module from importlib import import_module
from typing import List, Tuple from typing import List, Tuple
from datasets import Dataset
from huggingface_hub import delete_repo, login from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
...@@ -1509,6 +1511,56 @@ class TFModelTesterMixin: ...@@ -1509,6 +1511,56 @@ class TFModelTesterMixin:
observed_main_input_name = list(model_signature.parameters.keys())[1] observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name) self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_dataset_conversion(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
tf_inputs_dict = {
key: val
for key, val in tf_inputs_dict.items()
if "head_mask" not in key and isinstance(val, tf.Tensor)
}
tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
input_dataset = Dataset.from_dict(tf_inputs_dict)
tf_dataset = model.prepare_tf_dataset(
input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
)
test_batch = next(iter(tf_dataset))
if isinstance(test_batch, tf.Tensor):
self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
else:
# Assert we discarded the unwanted extra column but kept everything else
self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
self.assertNotIn("extra_unwanted_column", test_batch)
for tensor in test_batch.values():
self.assertTrue(isinstance(tensor, tf.Tensor))
self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
model(test_batch, training=False)
if "labels" in inspect.signature(model_class.call).parameters.keys():
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
if "labels" not in tf_inputs_dict:
return # This model isn't giving us labels after all, don't try training with it
tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
input_dataset = Dataset.from_dict(tf_inputs_dict)
tf_dataset = model.prepare_tf_dataset(
input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
)
test_batch, test_batch_labels = next(iter(tf_dataset))
self.assertGreater(len(test_batch_labels), 0) # Assert the labels are present
feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
# Assert we discarded the unwanted extra column but kept everything else
self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
if isinstance(test_batch, dict):
self.assertNotIn("extra_unwanted_column", test_batch)
if isinstance(test_batch_labels, dict):
self.assertNotIn("extra_unwanted_column", test_batch_labels)
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)
def _generate_random_bad_tokens(self, num_bad_tokens, model): def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens # special tokens cannot be bad tokens
special_tokens = [] special_tokens = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment