"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5362bb8a6b86b51851012d9166da339f23cf734d"
Unverified commit 54f9fbef, authored by Julien Plu, committed by GitHub

Rework TF trainer (#6038)

* Fully rework training/prediction loops

* Fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* Fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* Fix eval loss value

* Use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* Revert previous commit

* Fix data_dir

* Apply style

* Rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
parent 3f94170a
 # Examples
 Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
-Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
+Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
 Here is the list of all our examples:
 - **grouped by task** (all official examples work for multiple models)
...
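The requirement bump from TensorFlow 2.1+ to 2.2+ reflects the commit's "Force TF at least 2.2 for the trainer" change. A hedged sketch of what such a guard can look like (illustrative only; the PR's actual check may differ):

```python
# Illustrative version guard; the PR's actual enforcement may differ.
import tensorflow as tf
from packaging import version

if version.parse(tf.__version__) < version.parse("2.2.0"):
    raise RuntimeError(f"TFTrainer requires TensorFlow >= 2.2.0, found {tf.__version__}")
```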
@@ -204,6 +204,8 @@ if is_tf_available():
             )

         def get_dataset(self):
+            self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
             return self.dataset

         def __len__(self):
...
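For context, `tf.data.experimental.assert_cardinality` (new in TF 2.2, hence the version bump) stamps a known size onto a dataset whose cardinality TensorFlow cannot infer on its own, which is what lets the reworked trainer compute steps per epoch without iterating the data. A minimal standalone sketch:

```python
import tensorflow as tf

# Generator-backed datasets report UNKNOWN_CARDINALITY (-2) because TF
# cannot know their length without iterating them.
ds = tf.data.Dataset.from_generator(lambda: iter(range(10)), output_types=tf.int32)
print(tf.data.experimental.cardinality(ds).numpy())  # -2

# Asserting the size we already know (here: 10, like len(self.features))
# makes the cardinality available to consumers such as the trainer.
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds).numpy())  # 10
```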
@@ -21,6 +21,8 @@ import os
 from dataclasses import dataclass, field
 from typing import Optional

+import tensorflow as tf
+
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
     data_dir: Optional[str] = field(
         default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
     )
+    use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -170,7 +173,7 @@ def main():
     )

     # Get datasets
-    if not data_args.data_dir:
+    if data_args.use_tfds:
         if data_args.version_2_with_negative:
             logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
@@ -179,7 +182,7 @@ def main():
         except ImportError:
             raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

-        tfds_examples = tfds.load("squad")
+        tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
         train_examples = (
             SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
             if training_args.do_train
@@ -209,6 +212,8 @@ def main():
         else None
     )

+    train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
+
    eval_dataset = (
        squad_convert_examples_to_features(
            examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
         else None
     )

+    eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
+
     # Initialize our Trainer
     trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
...
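The commit list also mentions adding `drop_remainder` for the test dataset and making tfds compliant with TPU. TPU execution requires static shapes, and `drop_remainder=True` guarantees every batch has the same size; a small sketch of the idea (illustrative, not the PR's exact code):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
# Without drop_remainder the last batch would have shape (2,), a dynamic
# shape that TPU/XLA compilation cannot handle.
for batch in ds.batch(4, drop_remainder=True):
    print(batch.shape)  # always (4,); the 2 leftover examples are dropped
```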
@@ -9,6 +9,7 @@ from enum import Enum
 from typing import Dict, Optional

 import numpy as np
+import tensorflow as tf
 import tensorflow_datasets as tfds

 from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
 def get_tfds(
-    task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
+    task_name: str,
+    tokenizer: PreTrainedTokenizer,
+    max_seq_length: Optional[int] = None,
+    mode: Split = Split.train,
+    data_dir: str = None,
 ):
     if task_name == "mnli-mm" and mode == Split.dev:
         tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
     else:
         tfds_name = task_name

-    ds = tfds.load("glue/" + tfds_name, split=mode.value)
+    ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
+    ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))

-    return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    return ds

 logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
     """

     task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -171,13 +179,22 @@ def main():
     # Get datasets
     train_dataset = (
-        get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
+        get_tfds(
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            data_dir=data_args.data_dir,
+        )
         if training_args.do_train
         else None
     )
     eval_dataset = (
         get_tfds(
-            task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            mode=Split.dev,
+            data_dir=data_args.data_dir,
        )
         if training_args.do_eval
         else None
     )
...
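The `with_info=True` pattern above is standard `tensorflow_datasets` usage: the call returns the dataset together with a `DatasetInfo` object whose split metadata carries exact example counts, so the cardinality can be asserted without a pass over the data. A standalone sketch:

```python
import tensorflow_datasets as tfds

# with_info=True returns a DatasetInfo alongside the dataset itself.
ds, info = tfds.load("glue/mrpc", split="train", with_info=True)
print(info.splits["train"].num_examples)  # 3668 examples in MRPC's train split
```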
@@ -17,7 +17,6 @@
 import logging
 import os
-import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
@@ -185,11 +184,6 @@ def main():
     for i in range(batch_size):
         for j in range(seq_len):
-            if label_ids[i, j] == -1:
-                label_ids[i, j] = -100
-                warnings.warn(
-                    "Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
-                )
             if label_ids[i, j] != -100:
                 out_label_list[i].append(label_map[label_ids[i][j]])
                 preds_list[i].append(label_map[preds[i][j]])
...
@@ -146,7 +146,7 @@ if is_tf_available():
         """

         features: List[InputFeatures]
-        pad_token_label_id: int = -1
+        pad_token_label_id: int = -100
         # Use cross entropy ignore_index as padding label id so that only
         # real label ids contribute to the loss later.
@@ -221,6 +221,8 @@ if is_tf_available():
             )

         def get_dataset(self):
+            self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
             return self.dataset

         def __len__(self):
...
@@ -17,7 +17,6 @@
 import functools
 import logging
 import os
-import warnings
 from typing import Dict, List, Optional, Union

 import h5py
@@ -174,11 +173,7 @@ class TFTokenClassificationLoss:
         )
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
-            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
+        active_loss = tf.reshape(labels, (-1,)) != -100
         reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
         labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
...
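The simplified mask drops the deprecated `-1` sentinel and standardizes on `-100`, the same value as PyTorch's `CrossEntropyLoss` default `ignore_index`, keeping the PT and TF token-classification losses consistent. A minimal standalone sketch of the masking logic:

```python
import tensorflow as tf

# Toy batch: 2 sequences of 3 tokens each, 5 label classes; -100 marks
# positions (padding/special tokens) that must not contribute to the loss.
labels = tf.constant([[1, -100, 2], [0, 1, -100]])
logits = tf.random.uniform((2, 3, 5))

active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 5)), active_loss)
active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
print(active_labels.numpy())  # [1 2 0 1] -- only real tokens survive
```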
(A large diff is collapsed here and not shown.)
@@ -162,7 +162,7 @@ class TFTrainingArguments(TrainingArguments):
                 "version. Using `--per_device_train_batch_size` is preferred."
             )
         per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
-        return per_device_batch_size * max(1, self.n_replicas)
+        return per_device_batch_size * self.n_replicas

     @property
     def eval_batch_size(self) -> int:
@@ -175,7 +175,7 @@ class TFTrainingArguments(TrainingArguments):
                 "version. Using `--per_device_eval_batch_size` is preferred."
             )
         per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
-        return per_device_batch_size * max(1, self.n_replicas)
+        return per_device_batch_size * self.n_replicas

     @property
     @tf_required
...
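Dropping the `max(1, ...)` guard is safe as long as `n_replicas` always reports at least one replica, which is how `tf.distribute` behaves: even the default no-op strategy reports `num_replicas_in_sync == 1`. A standalone sketch (assuming `n_replicas` wraps `num_replicas_in_sync`, as the property name suggests):

```python
import tensorflow as tf

# Even outside any distribution scope, the default strategy reports one
# replica, so multiplying by num_replicas_in_sync never zeroes the batch.
strategy = tf.distribute.get_strategy()
per_device_batch_size = 8
print(per_device_batch_size * strategy.num_replicas_in_sync)  # 8 on one device
```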