Unverified commit 54f9fbef authored by Julien Plu, committed by GitHub

Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
Parent: 3f94170a
# Examples
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
-Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
+Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
Here is the list of all our examples:
- **grouped by task** (all official examples work for multiple models)
......
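For orientation, here is a minimal sketch of what fine-tuning with the new `TFTrainer` looks like. This is editorial illustration, not part of the diff: the checkpoint name and the toy dataset are placeholders.

```python
import tensorflow as tf
from transformers import (
    TFAutoModelForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)

training_args = TFTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

# Build the model under the distribution strategy so that it is replicated
# correctly on multi-GPU/TPU setups.
with training_args.strategy.scope():
    model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# A toy stand-in for real features: TFTrainer consumes a tf.data.Dataset of
# (inputs, labels) pairs whose cardinality is known.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        {
            "input_ids": tf.zeros((8, 128), tf.int32),
            "attention_mask": tf.ones((8, 128), tf.int32),
        },
        tf.zeros((8,), tf.int64),
    )
)

trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```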
@@ -204,6 +204,8 @@ if is_tf_available():
            )

        def get_dataset(self):
+            self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
            return self.dataset

        def __len__(self):
......
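Several hunks in this commit wrap datasets with `tf.data.experimental.assert_cardinality` (new in TF 2.2, hence the version bump above) so the trainer can read the dataset size instead of iterating to count it. A self-contained illustration with a toy generator dataset:

```python
import tensorflow as tf

# Generator-backed datasets have UNKNOWN cardinality: their size cannot be
# read without consuming them.
ds = tf.data.Dataset.from_generator(lambda: range(10), output_types=tf.int32)
print(tf.data.experimental.cardinality(ds).numpy())  # -2, i.e. UNKNOWN_CARDINALITY

# assert_cardinality records the size on the dataset (and fails at runtime if
# it turns out to be wrong), so steps per epoch can be computed up front.
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds).numpy())  # 10
```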
@@ -21,6 +21,8 @@ import os
from dataclasses import dataclass, field
from typing import Optional

+import tensorflow as tf
+
from transformers import (
    AutoConfig,
    AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
    data_dir: Optional[str] = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
+    use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
    max_seq_length: int = field(
        default=128,
        metadata={
@@ -170,7 +173,7 @@ def main():
    )

    # Get datasets
-    if not data_args.data_dir:
+    if data_args.use_tfds:
        if data_args.version_2_with_negative:
            logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
@@ -179,7 +182,7 @@ def main():
        except ImportError:
            raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

-        tfds_examples = tfds.load("squad")
+        tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)

        train_examples = (
            SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
            if training_args.do_train
@@ -209,6 +212,8 @@ def main():
        else None
    )

+    train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
+
    eval_dataset = (
        squad_convert_examples_to_features(
            examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
        else None
    )

+    eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
+
    # Initialize our Trainer
    trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
......
@@ -9,6 +9,7 @@ from enum import Enum
from typing import Dict, Optional

import numpy as np
+import tensorflow as tf
import tensorflow_datasets as tfds

from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
def get_tfds(
-    task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
+    task_name: str,
+    tokenizer: PreTrainedTokenizer,
+    max_seq_length: Optional[int] = None,
+    mode: Split = Split.train,
+    data_dir: str = None,
):
    if task_name == "mnli-mm" and mode == Split.dev:
        tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
    else:
        tfds_name = task_name

-    ds = tfds.load("glue/" + tfds_name, split=mode.value)
+    ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
+    ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))

-    return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    return ds


logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
    max_seq_length: int = field(
        default=128,
        metadata={
@@ -171,13 +179,22 @@ def main():
    # Get datasets
    train_dataset = (
-        get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
+        get_tfds(
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            data_dir=data_args.data_dir,
+        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        get_tfds(
-            task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            mode=Split.dev,
+            data_dir=data_args.data_dir,
+        )
        if training_args.do_eval
        else None
......
@@ -17,7 +17,6 @@
import logging
import os
-import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
@@ -185,11 +184,6 @@ def main():
    for i in range(batch_size):
        for j in range(seq_len):
-            if label_ids[i, j] == -1:
-                label_ids[i, j] = -100
-                warnings.warn(
-                    "Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
-                )
            if label_ids[i, j] != -100:
                out_label_list[i].append(label_map[label_ids[i][j]])
                preds_list[i].append(label_map[preds[i][j]])
......
@@ -146,7 +146,7 @@ if is_tf_available():
        """

        features: List[InputFeatures]
-        pad_token_label_id: int = -1
+        pad_token_label_id: int = -100
        # Use cross entropy ignore_index as padding label id so that only
        # real label ids contribute to the loss later.
@@ -221,6 +221,8 @@ if is_tf_available():
            )

        def get_dataset(self):
+            self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
            return self.dataset

        def __len__(self):
......
@@ -17,7 +17,6 @@
import functools
import logging
import os
-import warnings
from typing import Dict, List, Optional, Union

import h5py
@@ -174,11 +173,7 @@ class TFTokenClassificationLoss:
        )
        # make sure only labels that are not equal to -100
        # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
-            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
+        active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
......
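The hunk above removes the deprecated `-1` escape hatch, leaving `-100` as the only ignore index (matching the `ignore_index` default of PyTorch's cross entropy). A standalone sketch of the masking logic, with toy shapes chosen purely for illustration:

```python
import tensorflow as tf

# One sequence of 4 tokens over 3 classes; the last two positions are padding
# and carry the ignore label -100.
logits = tf.random.uniform((1, 4, 3))
labels = tf.constant([[1, 2, -100, -100]])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

# Flatten, keep only positions whose label is not -100, and compute the loss
# on those positions alone.
active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 3)), active_loss)
active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
print(loss_fn(active_labels, reduced_logits).shape)  # (2,): padding excluded
```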
This diff is collapsed.
@@ -162,7 +162,7 @@ class TFTrainingArguments(TrainingArguments):
            "version. Using `--per_device_train_batch_size` is preferred."
        )
        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
-        return per_device_batch_size * max(1, self.n_replicas)
+        return per_device_batch_size * self.n_replicas

    @property
    def eval_batch_size(self) -> int:
@@ -175,7 +175,7 @@ class TFTrainingArguments(TrainingArguments):
            "version. Using `--per_device_eval_batch_size` is preferred."
        )
        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
-        return per_device_batch_size * max(1, self.n_replicas)
+        return per_device_batch_size * self.n_replicas

    @property
    @tf_required
......
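Dropping the `max(1, ...)` guard is safe assuming `n_replicas` is backed by the distribution strategy's `num_replicas_in_sync`, which TF reports as 1 even without any accelerator:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # the default strategy on CPU or a single GPU
print(strategy.num_replicas_in_sync)  # 1, so max(1, n_replicas) was redundant
```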