Unverified Commit 6f289dc9 authored by Julien Plu, committed by GitHub

Fix the TF Trainer gradient accumulation and the TF NER example (#6713)

* Align TF NER example over the PT one

* Fix Dataset call

* Fix gradient accumulation training

* Apply style

* Address Sylvain's comments

* Address Sylvain's comments

* Apply style
parent 41aa2b4e
@@ -18,6 +18,7 @@
 import logging
 import os
 from dataclasses import dataclass, field
+from importlib import import_module
 from typing import Dict, List, Optional, Tuple

 import numpy as np
@@ -32,7 +33,7 @@ from transformers import (
     TFTrainer,
     TFTrainingArguments,
 )
-from utils_ner import Split, TFNerDataset, get_labels
+from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask


 logger = logging.getLogger(__name__)
@@ -50,6 +51,9 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
+    task_type: Optional[str] = field(
+        default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
+    )
     tokenizer_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
@@ -102,6 +106,17 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
         )

+    module = import_module("tasks")
+
+    try:
+        token_classification_task_clazz = getattr(module, model_args.task_type)
+        token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+    except AttributeError:
+        raise ValueError(
+            f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+            f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+        )
+
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
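
Note: the `import_module("tasks")` lookup above expects a `tasks.py` module importable next to the script, defining one `TokenClassificationTask` subclass per `--task_type` value. A minimal sketch of such an entry, assuming only the `get_labels` method that this diff actually calls (the rest of the interface and the CoNLL-2003 fallback labels are assumptions for illustration):

# tasks.py -- hypothetical sketch; only get_labels is exercised by this diff
from typing import List

from utils_ner import TokenClassificationTask


class NER(TokenClassificationTask):
    def get_labels(self, path: str) -> List[str]:
        # One label per line; fall back to a standard CoNLL-2003 label set.
        if path:
            with open(path, "r") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

With this in place, `getattr(module, "NER")` resolves the class, and `TokenClassificationTask.__subclasses__()` lists it in the error message raised for unknown task types.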
@@ -117,7 +132,7 @@ def main():
     logger.info("Training/evaluation parameters %s", training_args)

     # Prepare Token Classification task
-    labels = get_labels(data_args.labels)
+    labels = token_classification_task.get_labels(data_args.labels)
     label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
     num_labels = len(labels)
@@ -150,7 +165,8 @@ def main():
     # Get datasets
     train_dataset = (
-        TFNerDataset(
+        TFTokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
...@@ -163,7 +179,8 @@ def main(): ...@@ -163,7 +179,8 @@ def main():
else None else None
) )
eval_dataset = ( eval_dataset = (
TFNerDataset( TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir, data_dir=data_args.data_dir,
tokenizer=tokenizer, tokenizer=tokenizer,
labels=labels, labels=labels,
@@ -233,7 +250,8 @@ def main():
     # Predict
     if training_args.do_predict:
-        test_dataset = TFNerDataset(
+        test_dataset = TFTokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
...
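
Usage sketch: with the new `--task_type` argument the example might be launched as below. The script name and every flag except `--task_type` are assumed from the usual example-script conventions and are not shown in this diff:

# hypothetical invocation; script name and flag names other than --task_type are assumed
python run_tf_ner.py \
    --task_type NER \
    --model_name_or_path bert-base-multilingual-cased \
    --data_dir ./data \
    --labels ./data/labels.txt \
    --output_dir ./ner-model \
    --max_seq_length 128 \
    --do_train \
    --do_eval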
@@ -276,7 +276,7 @@ if is_torch_available():
 if is_tf_available():
     import tensorflow as tf

-    class TFNerDataset:
+    class TFTokenClassificationDataset:
         """
         This will be superseded by a framework-agnostic approach
         soon.
...
@@ -174,7 +174,7 @@ class TFTokenClassificationLoss:
         )
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
+        if tf.math.reduce_any(labels == -1):
             warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
             active_loss = tf.reshape(labels, (-1,)) != -1
         else:
...
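
Why the old check never fired: `.numpy()` on a scalar boolean tensor returns a `numpy.bool_`, which is never identical to the Python `True` singleton, so the `is True` comparison was always `False`; calling `.numpy()` would also fail once the loss runs inside a `tf.function` graph. A minimal repro sketch:

import numpy as np
import tensorflow as tf

labels = tf.constant([[0, -1, 2]])

# Old check: numpy.bool_ is a distinct object from Python's True, so the
# identity test always failed and the deprecation branch was unreachable.
print(tf.math.reduce_any(labels == -1).numpy() is True)  # False
print(type(tf.math.reduce_any(labels == -1).numpy()))    # <class 'numpy.bool_'>

# New check: a scalar boolean tensor is truthy in eager mode, and AutoGraph
# converts the `if` when running under tf.function, so the branch now fires.
if tf.math.reduce_any(labels == -1):
    print("found deprecated -1 mask values")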
@@ -620,13 +620,22 @@ class TFTrainer:
             self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
         else:
             for _ in tf.range(self.args.gradient_accumulation_steps):
-                reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
-                reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
+                reduced_features = {
+                    k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
+                }
+                reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]

                 self.training_step(reduced_features, reduced_labels)

-                features = tf.concat(
-                    [features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
-                )
+                features = {
+                    k: tf.concat(
+                        [ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]], axis=0,
+                    )
+                    for k, ft in features.items()
+                }
+
+                labels = tf.concat(
+                    [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
+                )

             gradients = self.gradient_accumulator.gradients
...
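
The rewritten loop fixes two bugs at once: `features` is a dict of tensors (input_ids, attention_mask, ...), so it cannot be sliced or concatenated as a whole, and `/` yields a float that is not a valid tensor slice index, hence `//`. A toy sketch of the rotation it performs, with hypothetical sizes (an accumulated batch of 4 examples, 1 replica, 2 accumulation steps, so each step consumes 2 examples):

import tensorflow as tf

features = {"input_ids": tf.reshape(tf.range(8), (4, 2))}
labels = tf.constant([10, 11, 12, 13])
step = 4 // 2  # plays the role of train_batch_size // n_replicas

for _ in range(2):  # gradient_accumulation_steps
    # Slice every tensor in the dict; features[:step] would fail on a dict,
    # and a float index from `/` would fail on a tensor.
    reduced_features = {k: ft[:step] for k, ft in features.items()}
    reduced_labels = labels[:step]
    # ... gradients for (reduced_features, reduced_labels) accumulate here ...

    # Rotate the consumed rows to the back so the next step sees fresh ones.
    features = {k: tf.concat([ft[step:], reduced_features[k]], axis=0) for k, ft in features.items()}
    labels = tf.concat([labels[step:], reduced_labels], axis=0)

print(labels.numpy())  # [10 11 12 13] -- original order restored after a full pass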