"...git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "77745a84ea7eceab388bc3c0335b0dcd3ecb3395"
Unverified commit 6f289dc9, authored by Julien Plu, committed by GitHub

Fix the TF Trainer gradient accumulation and the TF NER example (#6713)

* Align the TF NER example with the PT one

* Fix Dataset call

* Fix gradient accumulation training

* Apply style

* Address Sylvain's comments

* Address Sylvain's comments

* Apply style
parent 41aa2b4e
@@ -18,6 +18,7 @@
 import logging
 import os
 from dataclasses import dataclass, field
+from importlib import import_module
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
@@ -32,7 +33,7 @@ from transformers import (
     TFTrainer,
     TFTrainingArguments,
 )
-from utils_ner import Split, TFNerDataset, get_labels
+from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
 
 
 logger = logging.getLogger(__name__)
@@ -50,6 +51,9 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
+    task_type: Optional[str] = field(
+        default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
+    )
     tokenizer_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
@@ -102,6 +106,17 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
         )
 
+    module = import_module("tasks")
+
+    try:
+        token_classification_task_clazz = getattr(module, model_args.task_type)
+        token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+    except AttributeError:
+        raise ValueError(
+            f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+            f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+        )
+
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -117,7 +132,7 @@ def main():
     logger.info("Training/evaluation parameters %s", training_args)
 
     # Prepare Token Classification task
-    labels = get_labels(data_args.labels)
+    labels = token_classification_task.get_labels(data_args.labels)
     label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
     num_labels = len(labels)
@@ -150,7 +165,8 @@ def main():
 
     # Get datasets
     train_dataset = (
-        TFNerDataset(
+        TFTokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
@@ -163,7 +179,8 @@ def main():
         else None
     )
     eval_dataset = (
-        TFNerDataset(
+        TFTokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
@@ -233,7 +250,8 @@ def main():
 
     # Predict
     if training_args.do_predict:
-        test_dataset = TFNerDataset(
+        test_dataset = TFTokenClassificationDataset(
+            token_classification_task=token_classification_task,
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
...
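With this change the example no longer hard-codes NER: --task_type names a class that the script looks up at runtime in a local tasks module, and that class must subclass TokenClassificationTask. The tasks module itself is not part of this diff, so the following is only a minimal sketch of what such a subclass could look like; everything in it except TokenClassificationTask and get_labels is illustrative.

# tasks.py -- hypothetical sketch; only TokenClassificationTask and its
# get_labels method are attested by this diff, the rest is illustrative.
from typing import List, Optional

from utils_ner import TokenClassificationTask


class NER(TokenClassificationTask):
    # Resolved at runtime via getattr(import_module("tasks"), model_args.task_type).
    def get_labels(self, path: Optional[str]) -> List[str]:
        # Read one label per line from a file, or fall back to a default tag set.
        if path:
            with open(path) as f:
                labels = f.read().splitlines()
            return labels if "O" in labels else ["O"] + labels
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

Defining a sibling class such as POS next to it would then be enough to make --task_type POS work without touching the run script.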
@@ -276,7 +276,7 @@ if is_torch_available():
 if is_tf_available():
     import tensorflow as tf
 
-    class TFNerDataset:
+    class TFTokenClassificationDataset:
         """
         This will be superseded by a framework-agnostic approach
         soon.
...
@@ -174,7 +174,7 @@ class TFTokenClassificationLoss:
         )
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
+        if tf.math.reduce_any(labels == -1):
             warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
             active_loss = tf.reshape(labels, (-1,)) != -1
         else:
...
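The removed `.numpy() is True` comparison was broken in two independent ways: tf.math.reduce_any(...).numpy() returns a numpy.bool_, which is never identical to Python's True singleton, so the deprecation branch could never be taken; and .numpy() is unavailable once the loss is traced inside a tf.function. A small eager-mode demonstration of the identity pitfall:

import tensorflow as tf

labels = tf.constant([[1, -1], [2, 3]])
flag = tf.math.reduce_any(labels == -1).numpy()

print(type(flag))    # <class 'numpy.bool_'>
print(flag is True)  # False -- numpy.bool_ is a different object than True
print(bool(flag))    # True -- the truthiness test that was actually intended

The fixed code passes the tensor straight to `if`, which works eagerly and lets AutoGraph convert the branch when the surrounding code is traced.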
@@ -620,13 +620,22 @@ class TFTrainer:
             self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
         else:
             for _ in tf.range(self.args.gradient_accumulation_steps):
-                reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
-                reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
+                reduced_features = {
+                    k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
+                }
+                reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
 
                 self.training_step(reduced_features, reduced_labels)
 
-                features = tf.concat(
-                    [features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
-                )
+                features = {
+                    k: tf.concat(
+                        [ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]], axis=0,
+                    )
+                    for k, ft in features.items()
+                }
+
+                labels = tf.concat(
+                    [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
+                )
 
             gradients = self.gradient_accumulator.gradients
...
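This is the gradient-accumulation fix the commit title refers to. The old code treated features as a single tensor, but the datasets feed a dict of tensors (input_ids, attention_mask, ...), so slicing the dict directly fails; it also used true division /, which yields a float that cannot index a tensor. The new code slices every tensor in the dict with integer // division and rotates both features and labels so each accumulation step consumes the next per-replica micro-batch. A toy eager-mode illustration of that rotation (values made up, not from the diff):

import tensorflow as tf

# A "global" batch of 4 examples, 2 accumulation steps, micro-batch of 2.
features = {"input_ids": tf.constant([[1], [2], [3], [4]])}
labels = tf.constant([10, 20, 30, 40])
micro = 2  # plays the role of train_batch_size // n_replicas

for step in range(2):
    reduced_features = {k: ft[:micro] for k, ft in features.items()}
    reduced_labels = labels[:micro]
    print(step, reduced_features["input_ids"].numpy().ravel(), reduced_labels.numpy())

    # Rotate so the next step sees the not-yet-consumed examples first.
    features = {k: tf.concat([ft[micro:], reduced_features[k]], axis=0) for k, ft in features.items()}
    labels = tf.concat([labels[micro:], reduced_labels], axis=0)

# step 0 consumes examples [1 2] with labels [10 20];
# step 1 consumes examples [3 4] with labels [30 40].

Rotating labels in lockstep with features is what keeps each micro-batch's inputs paired with its own targets across accumulation steps.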