Unverified Commit 62d84760 authored by Joao Gante, committed by GitHub

Update TF multiple choice example (#15868)

parent ab2f8d12
@@ -24,10 +24,9 @@ import sys
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import datasets
import numpy as np
import tensorflow as tf
from datasets import load_dataset
@@ -37,12 +36,15 @@ from transformers import (
TF2_WEIGHTS_NAME,
AutoConfig,
AutoTokenizer,
DefaultDataCollator,
HfArgumentParser,
TFAutoModelForMultipleChoice,
TFTrainingArguments,
create_optimizer,
set_seed,
)
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import check_min_version
@@ -65,51 +67,61 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
self.model.save_pretrained(self.output_dir)
def convert_dataset_for_tensorflow(
dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
):
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
@dataclass
class DataCollatorForMultipleChoice:
"""
Data collator that will dynamically pad the inputs for multiple choice tasks.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.0 (Volta).
"""
def densify_ragged_batch(features, label=None):
features = {
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items()
}
if label is None:
return features
else:
return features, label
feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
if dataset_mode == "variable_batch":
batch_shape = {key: None for key in feature_keys}
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
elif dataset_mode == "constant_batch":
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
batch_shape = {
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
for key, ragged_tensor in data.items()
}
else:
raise ValueError("Unknown dataset mode!")
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
]
flattened_features = list(chain(*flattened_features))
batch = self.tokenizer.pad(
flattened_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="tf",
)
if "label" in dataset.features:
labels = tf.convert_to_tensor(np.array(dataset["label"]))
tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
else:
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
if shuffle:
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
tf_dataset = (
tf_dataset.with_options(options)
.batch(batch_size=batch_size, drop_remainder=drop_remainder)
.map(densify_ragged_batch)
)
return tf_dataset
# Un-flatten
batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
# Add back labels
batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
return batch
# endregion
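As a quick, illustrative usage sketch of the collator above (not part of the commit; the checkpoint name, sentences, and labels below are placeholders), two toy SWAG-style examples with four candidate endings each are tokenized as (context, ending) pairs and then padded and reshaped by the collator:

# Hypothetical usage sketch for DataCollatorForMultipleChoice; checkpoint, sentences and
# labels are made up for illustration only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForMultipleChoice(tokenizer)

contexts = ["She opened the door.", "He picked up the ball."]
endings = [
    ["She walked in.", "She flew away.", "She sang.", "She slept."],
    ["He threw it.", "He ate it.", "He read it.", "He planted it."],
]

features = []
for context, choice_endings in zip(contexts, endings):
    # One encoding per candidate ending: the context is paired with every choice.
    encoding = tokenizer([context] * len(choice_endings), choice_endings)
    features.append({**encoding, "label": 0})

batch = collator(features)
# Inputs come back padded and reshaped to (batch_size, num_choices, seq_len); labels are (batch_size,).
print({name: tensor.shape for name, tensor in batch.items()})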
@@ -382,6 +394,12 @@ def main():
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.pad_to_max_length:
data_collator = DefaultDataCollator(return_tensors="tf")
else:
# custom class defined above, as HF has no data collator for multiple choice
data_collator = DataCollatorForMultipleChoice(tokenizer)
# endregion
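For contrast, when --pad_to_max_length is passed every choice is already padded to the same length during preprocessing, so the stock DefaultDataCollator only needs to stack the nested lists into tensors. A minimal sketch with made-up token ids:

# Hypothetical sketch of the pad_to_max_length branch: features are pre-padded, so the
# default collator just stacks them into tensors (and renames "label" to "labels").
from transformers import DefaultDataCollator

collator = DefaultDataCollator(return_tensors="tf")
features = [
    {"input_ids": [[1, 2, 3, 0], [1, 4, 5, 0], [1, 6, 0, 0], [1, 7, 8, 9]], "label": 2},
    {"input_ids": [[1, 2, 0, 0], [1, 3, 4, 0], [1, 5, 6, 7], [1, 8, 0, 0]], "label": 0},
]
batch = collator(features)
print(batch["input_ids"].shape)  # (2, 4, 4)
print(batch["labels"].shape)     # (2,)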
with training_args.strategy.scope():
@@ -417,12 +435,26 @@ def main():
# region Training
if training_args.do_train:
tf_train_dataset = convert_dataset_for_tensorflow(
train_dataset, non_label_column_names=non_label_columns, batch_size=total_train_batch_size
dataset_exclude_cols = set(non_label_columns + ["label"])
tf_train_dataset = train_dataset.to_tf_dataset(
columns=[col for col in train_dataset.column_names if col not in dataset_exclude_cols],
shuffle=True,
batch_size=total_train_batch_size,
collate_fn=data_collator,
drop_remainder=True,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in train_dataset.column_names else None,
)
if training_args.do_eval:
validation_data = convert_dataset_for_tensorflow(
eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
validation_data = eval_dataset.to_tf_dataset(
columns=[col for col in eval_dataset.column_names if col not in dataset_exclude_cols],
shuffle=False,
batch_size=total_eval_batch_size,
collate_fn=data_collator,
drop_remainder=True,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in eval_dataset.column_names else None,
)
else:
validation_data = None
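Both converted objects are ordinary tf.data.Datasets that yield (features, labels) tuples, so they can be passed straight to Keras. A minimal sketch of the consumption step, assuming the model, optimizer and the SavePretrainedCallback defined earlier were set up in the regions this diff omits:

# Sketch only: the actual fit call lives in a part of the file not shown in this diff.
model.fit(
    tf_train_dataset,
    validation_data=validation_data,
    epochs=int(training_args.num_train_epochs),
    callbacks=[SavePretrainedCallback(output_dir=training_args.output_dir)],
)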
@@ -436,9 +468,16 @@ def main():
# region Evaluation
if training_args.do_eval and not training_args.do_train:
dataset_exclude_cols = set(non_label_columns + ["label"])
# Do a standalone evaluation pass
tf_eval_dataset = convert_dataset_for_tensorflow(
eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
tf_eval_dataset = eval_dataset.to_tf_dataset(
columns=[col for col in eval_dataset.column_names if col not in dataset_exclude_cols],
shuffle=False,
batch_size=total_eval_batch_size,
collate_fn=data_collator,
drop_remainder=True,
# `label_cols` is needed for user-defined losses, such as in this example
label_cols="label" if "label" in eval_dataset.column_names else None,
)
model.evaluate(tf_eval_dataset)
# endregion