"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "06f1692b023a701ab2bb443fa4f0bdd58c6bd234"
Commit 31a3a73e authored by thomwolf, committed by Morgan Funtowicz

updating CLI

parent 7c169756
@@ -3,6 +3,8 @@ from argparse import ArgumentParser

 from transformers.commands.serving import ServeCommand
 from transformers.commands.user import UserCommands
+from transformers.commands.train import TrainCommand
+from transformers.commands.convert import ConvertCommand

 if __name__ == '__main__':
     parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
@@ -11,6 +13,8 @@ if __name__ == '__main__':
     # Register commands
     ServeCommand.register_subcommand(commands_parser)
     UserCommands.register_subcommand(commands_parser)
+    TrainCommand.register_subcommand(commands_parser)
+    ConvertCommand.register_subcommand(commands_parser)

     # Let's go
     args = parser.parse_args()
...
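The two new commands plug into the same register_subcommand pattern the existing ServeCommand and UserCommands use: each command class attaches its own subparser and binds a factory through set_defaults, and the entry point dispatches on the parsed arguments. A minimal, self-contained sketch of that dispatch pattern (MiniCommand and its --name flag are hypothetical, not part of the transformers codebase):

from argparse import ArgumentParser

class MiniCommand:
    @staticmethod
    def register_subcommand(parser):
        # Attach a 'mini' subparser and bind this class as its factory
        sub = parser.add_parser('mini', help='Toy subcommand illustrating the dispatch pattern.')
        sub.add_argument('--name', type=str, default='world')
        sub.set_defaults(func=lambda args: MiniCommand(args.name))

    def __init__(self, name):
        self.name = name

    def run(self):
        print('Hello {}'.format(self.name))

if __name__ == '__main__':
    parser = ArgumentParser('demo', usage='demo <command> [<args>]')
    commands_parser = parser.add_subparsers(help='demo command helpers')
    MiniCommand.register_subcommand(commands_parser)
    args = parser.parse_args()
    # Build the selected command object, then run it
    args.func(args).run()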
@@ -25,7 +25,6 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH
 from .data import (is_sklearn_available,
                    InputExample, InputFeatures, DataProcessor,
                    SingleSentenceClassificationProcessor,
-                   convert_examples_to_features,
                    glue_output_modes, glue_convert_examples_to_features,
                    glue_processors, glue_tasks_num_labels,
                    xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
@@ -66,6 +65,9 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
 from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

+# Pipelines
+from .pipeline import TextClassificationPipeline
+
 # Modeling
 if is_torch_available():
     from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
...
@@ -3,14 +3,11 @@ from argparse import ArgumentParser, Namespace
 from logging import getLogger

 from transformers.commands import BaseTransformersCLICommand
-from transformers import (AutoTokenizer, is_tf_available, is_torch_available,
-                          SingleSentenceClassificationProcessor,
-                          convert_examples_to_features)
+from transformers import (is_tf_available, is_torch_available,
+                          TextClassificationPipeline,
+                          SingleSentenceClassificationProcessor as Processor)

-if is_tf_available():
-    from transformers import TFAutoModelForSequenceClassification as SequenceClassifModel
-elif is_torch_available():
-    from transformers import AutoModelForSequenceClassification as SequenceClassifModel
-else:
+if not is_tf_available() and not is_torch_available():
     raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

 # TF training parameters
@@ -35,16 +32,18 @@ class TrainCommand(BaseTransformersCLICommand):
         :return:
         """
         train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
         train_parser.add_argument('--train_data', type=str, required=True,
                                   help="path to train (and optionally evaluation) dataset as a csv with "
                                        "tab separated labels and sentences.")
         train_parser.add_argument('--column_label', type=int, default=0,
                                   help='Column of the dataset csv file with example labels.')
         train_parser.add_argument('--column_text', type=int, default=1,
                                   help='Column of the dataset csv file with example texts.')
         train_parser.add_argument('--column_id', type=int, default=2,
                                   help='Column of the dataset csv file with example ids.')
+        train_parser.add_argument('--skip_first_row', action='store_true',
+                                  help='Skip the first row of the csv file (headers).')
         train_parser.add_argument('--validation_data', type=str, default='',
                                   help='path to validation dataset.')
@@ -74,39 +73,38 @@ class TrainCommand(BaseTransformersCLICommand):
         self.framework = 'tf' if is_tf_available() else 'torch'

-        os.makedirs(args.output)
-        assert os.path.isdir(args.output)
+        os.makedirs(args.output, exist_ok=True)
         self.output = args.output

         self.column_label = args.column_label
         self.column_text = args.column_text
         self.column_id = args.column_id

-        self.logger.info('Loading model {}'.format(args.model_name))
-        self.model_name = args.model_name
-        self.pipeline = AutoTokenizer.from_pretrained(args.model_name)
+        self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
         if args.task == 'text_classification':
-            self.model = SequenceClassifModel.from_pretrained(args.model_name)
+            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
         elif args.task == 'token_classification':
             raise NotImplementedError
         elif args.task == 'question_answering':
             raise NotImplementedError

         self.logger.info('Loading dataset from {}'.format(args.train_data))
-        dataset = SingleSentenceClassificationProcessor.create_from_csv(args.train_data)
-        num_data_samples = len(dataset)
+        self.train_dataset = Processor.create_from_csv(args.train_data,
+                                                       column_label=args.column_label,
+                                                       column_text=args.column_text,
+                                                       column_id=args.column_id,
+                                                       skip_first_row=args.skip_first_row)
+        self.valid_dataset = None
         if args.validation_data:
             self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
-            self.valid_dataset = SingleSentenceClassificationProcessor.create_from_csv(args.validation_data)
-            self.num_valid_samples = len(self.valid_dataset)
-            self.train_dataset = dataset
-            self.num_train_samples = num_data_samples
-        else:
-            assert 0.0 < args.validation_split < 1.0, "--validation_split should be between 0.0 and 1.0"
-            self.num_valid_samples = num_data_samples * args.validation_split
-            self.num_train_samples = num_data_samples - self.num_valid_samples
-            self.train_dataset = dataset[self.num_train_samples]
-            self.valid_dataset = dataset[self.num_valid_samples]
+            self.valid_dataset = Processor.create_from_csv(args.validation_data,
+                                                           column_label=args.column_label,
+                                                           column_text=args.column_text,
+                                                           column_id=args.column_id,
+                                                           skip_first_row=args.skip_first_row)

+        self.validation_split = args.validation_split
         self.train_batch_size = args.train_batch_size
         self.valid_batch_size = args.valid_batch_size
         self.learning_rate = args.learning_rate
@@ -121,34 +119,13 @@ class TrainCommand(BaseTransformersCLICommand):
             raise NotImplementedError

     def run_tf(self):
-        import tensorflow as tf
-
-        tf.config.optimizer.set_jit(USE_XLA)
-        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
-
-        # Prepare dataset as a tf.train_data.Dataset instance
-        self.logger.info('Tokenizing and processing dataset')
-        train_dataset = self.train_dataset.get_features(self.tokenizer)
-        valid_dataset = self.valid_dataset.get_features(self.tokenizer)
-        train_dataset = train_dataset.shuffle(128).batch(self.train_batch_size).repeat(-1)
-        valid_dataset = valid_dataset.batch(self.valid_batch_size)
-
-        # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
-        opt = tf.keras.optimizers.Adam(learning_rate=args.learning_rate, epsilon=self.adam_epsilon)
-        if USE_AMP:
-            # loss scaling is currently required when using mixed precision
-            opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
-        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
-        self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
-
-        # Train and evaluate using tf.keras.Model.fit()
-        train_steps = self.num_train_samples//self.train_batch_size
-        valid_steps = self.num_valid_samples//self.valid_batch_size
-        self.logger.info('Training model')
-        history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
-                                 validation_data=valid_dataset, validation_steps=valid_steps)
-
-        # Save trained model
-        self.model.save_pretrained(self.output)
+        self.pipeline.fit(self.train_dataset,
+                          validation_data=self.valid_dataset,
+                          validation_split=self.validation_split,
+                          learning_rate=self.learning_rate,
+                          adam_epsilon=self.adam_epsilon,
+                          train_batch_size=self.train_batch_size,
+                          valid_batch_size=self.valid_batch_size)

+        # Save trained pipeline
+        self.pipeline.save_pretrained(self.output)
...
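Taken together, the train subcommand now loads the csv through the processor and delegates all training to TextClassificationPipeline. For reference, a hypothetical invocation; only --train_data, --validation_data, --column_label, --column_text, --column_id and --skip_first_row appear in the hunks above, while --task, --model and --output are inferred from the args attributes the code reads and may differ in the actual parser:

transformers-cli train --task text_classification --model bert-base-uncased \
                       --train_data ./train.csv --validation_data ./valid.csv \
                       --column_label 0 --column_text 1 --skip_first_row \
                       --output ./trained_pipeline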
@@ -122,14 +122,30 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         return self.examples[idx]

     @classmethod
-    def create_from_csv(cls, file_name, **kwargs):
+    def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
+                        column_id=None, skip_first_row=False, **kwargs):
         processor = cls(**kwargs)
-        processor.add_examples_from_csv(file_name)
+        processor.add_examples_from_csv(file_name,
+                                        split_name=split_name,
+                                        column_label=column_label,
+                                        column_text=column_text,
+                                        column_id=column_id,
+                                        skip_first_row=skip_first_row,
+                                        overwrite_labels=True,
+                                        overwrite_examples=True)
+        return processor
+
+    @classmethod
+    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
+        processor = cls(**kwargs)
+        processor.add_examples(texts_or_text_and_labels, labels=labels)
         return processor

     def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
-                              overwrite_labels=False, overwrite_examples=False):
+                              skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
         lines = self._read_tsv(file_name)
+        if skip_first_row:
+            lines = lines[1:]
         texts = []
         labels = []
         ids = []
@@ -144,15 +160,21 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)

-    def add_examples(self, texts, labels, ids=None, overwrite_labels=False, overwrite_examples=False):
+    def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
+                     overwrite_labels=False, overwrite_examples=False):
+        assert labels is None or len(texts_or_text_and_labels) == len(labels)
+        assert ids is None or len(texts_or_text_and_labels) == len(ids)
         if ids is None:
-            ids = [None] * len(texts)
-        assert len(texts) == len(labels)
-        assert len(texts) == len(ids)
+            ids = [None] * len(texts_or_text_and_labels)
+        if labels is None:
+            labels = [None] * len(texts_or_text_and_labels)

         examples = []
         added_labels = set()
-        for (text, label, guid) in zip(texts, labels, ids):
+        for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
+            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
+                text, label = text_or_text_and_label
+            else:
+                text = text_or_text_and_label
             added_labels.add(label)
             examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
@@ -170,12 +192,6 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         return self.examples

-    @classmethod
-    def create_from_examples(cls, texts, labels, **kwargs):
-        processor = cls(**kwargs)
-        processor.add_examples(texts, labels)
-        return processor
-
     def get_features(self,
                      tokenizer,
                      max_length=None,
@@ -204,6 +220,8 @@ class SingleSentenceClassificationProcessor(DataProcessor):
             a list of task-specific ``InputFeatures`` which can be fed to the model.
         """
+        if max_length is None:
+            max_length = tokenizer.max_len
         label_map = {label: i for i, label in enumerate(self.labels)}
...
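With these changes a processor can be built either from a tab-separated file with a header row or directly from in-memory examples, and labels may ride along as (text, label) pairs. A minimal sketch of the new surface (the example file name, strings and labels are made up; create_from_csv, create_from_examples and get_features are taken from the hunks above, and return_tensors='tf' assumes TensorFlow 2.0 is installed):

from transformers import AutoTokenizer, SingleSentenceClassificationProcessor

# From a tab-separated file whose first row is a header
processor = SingleSentenceClassificationProcessor.create_from_csv('train.csv',
                                                                  column_label=0,
                                                                  column_text=1,
                                                                  skip_first_row=True)

# Or from (text, label) pairs; with labels=None, add_examples unpacks each tuple
processor = SingleSentenceClassificationProcessor.create_from_examples(
    [('this movie was great', 'pos'), ('terrible pacing', 'neg')])

# max_length now defaults to tokenizer.max_len when omitted
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
features = processor.get_features(tokenizer, return_tensors='tf')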
@@ -22,6 +22,8 @@ import logging
 import os

 import tensorflow as tf
+from tensorflow.python.keras.saving import hdf5_format
+import h5py

 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -206,6 +208,9 @@ class TFPreTrainedModel(tf.keras.Model):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                 The proxies are used on each request.

+            output_loading_info: (`optional`) boolean:
+                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
+
             kwargs: (`optional`) Remaining dictionary of keyword arguments:
                 Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
@@ -229,6 +234,7 @@ class TFPreTrainedModel(tf.keras.Model):
         force_download = kwargs.pop('force_download', False)
         resume_download = kwargs.pop('resume_download', False)
         proxies = kwargs.pop('proxies', None)
+        output_loading_info = kwargs.pop('output_loading_info', False)

         # Load config
         if config is None:
@@ -304,6 +310,31 @@ class TFPreTrainedModel(tf.keras.Model):
         ret = model(model.dummy_inputs, training=False)  # Make sure restore ops are run

+        # Check if the models are the same to output loading information
+        with h5py.File(resolved_archive_file, 'r') as f:
+            if 'layer_names' not in f.attrs and 'model_weights' in f:
+                f = f['model_weights']
+            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, 'layer_names'))
+        model_layer_names = set(layer.name for layer in model.layers)
+        missing_keys = list(model_layer_names - hdf5_layer_names)
+        unexpected_keys = list(hdf5_layer_names - model_layer_names)
+        error_msgs = []
+
+        if len(missing_keys) > 0:
+            logger.info("Layers of {} not initialized from pretrained model: {}".format(
+                model.__class__.__name__, missing_keys))
+        if len(unexpected_keys) > 0:
+            logger.info("Layers from pretrained model not used in {}: {}".format(
+                model.__class__.__name__, unexpected_keys))
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading weights for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+
+        if output_loading_info:
+            loading_info = {"missing_keys": missing_keys,
+                            "unexpected_keys": unexpected_keys,
+                            "error_msgs": error_msgs}
+            return model, loading_info
+
         return model


 class TFConv1D(tf.keras.layers.Layer):
...
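The new output_loading_info flag surfaces the H5 layer-name comparison to callers, which is what the pipeline below uses to decide whether a classification head still needs fine-tuning. A short sketch of how a caller might use it (the model name is only an example):

from transformers import TFAutoModelForSequenceClassification

model, loading_info = TFAutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', output_loading_info=True)

# Layers present in the model but absent from the H5 checkpoint,
# e.g. a freshly initialized classification head
print(loading_info['missing_keys'])
# Layers present in the checkpoint but not used by this architecture
print(loading_info['unexpected_keys'])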
@@ -17,18 +17,22 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
 import logging
+import six

-from .modeling_auto import (AutoModel, AutoModelForQuestionAnswering,
-                            AutoModelForSequenceClassification,
-                            AutoModelWithLMHead)
 from .tokenization_auto import AutoTokenizer
 from .file_utils import add_start_docstrings, is_tf_available, is_torch_available
 from .data.processors import SingleSentenceClassificationProcessor

 if is_tf_available():
     import tensorflow as tf
+    from .modeling_tf_auto import (TFAutoModel, TFAutoModelForQuestionAnswering,
+                                   TFAutoModelForSequenceClassification,
+                                   TFAutoModelWithLMHead)

 if is_torch_available():
     import torch
+    from .modeling_auto import (AutoModel, AutoModelForQuestionAnswering,
+                                AutoModelForSequenceClassification,
+                                AutoModelWithLMHead)

 logger = logging.getLogger(__name__)
@@ -61,12 +65,6 @@ class TextClassificationPipeline(object):
     def __init__(self, tokenizer, model, is_compiled=False, is_trained=False):
         self.tokenizer = tokenizer
         self.model = model
-        if is_tf_available():
-            self.framework = 'tf'
-        elif is_torch_available():
-            self.framework = 'pt'
-        else:
-            raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
         self.is_compiled = is_compiled
         self.is_trained = is_trained
@@ -94,9 +92,12 @@ class TextClassificationPipeline(object):
                 # used for both the tokenizer and the model
                 model_kwargs[key] = kwargs[key]

-        model_kwargs['output_loading_info'] = True
         tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **tokenizer_kwargs)
-        model, loading_info = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
+        model_kwargs['output_loading_info'] = True
+        if is_tf_available():
+            model, loading_info = TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
+        else:
+            model, loading_info = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)

         return cls(tokenizer, model, is_trained=bool(not loading_info['missing_keys']))
@@ -109,36 +110,42 @@ class TextClassificationPipeline(object):
         self.tokenizer.save_pretrained(save_directory)

-    def prepare_data(self, train_samples_text, train_samples_labels,
-                     valid_samples_text=None, valid_samples_labels=None,
+    def prepare_data(self, x, y=None,
+                     validation_data=None,
                      validation_split=0.1, **kwargs):
-        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
-                                                                             train_samples_labels)
+        dataset = x
+        if not isinstance(x, SingleSentenceClassificationProcessor):
+            dataset = SingleSentenceClassificationProcessor.create_from_examples(x, y)
         num_data_samples = len(dataset)
-        if valid_samples_text is not None and valid_samples_labels is not None:
-            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
-                                                                                       valid_samples_labels)
+
+        if validation_data is not None:
+            valid_dataset = validation_data
+            if not isinstance(validation_data, SingleSentenceClassificationProcessor):
+                valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(validation_data)
             num_valid_samples = len(valid_dataset)
             train_dataset = dataset
             num_train_samples = num_data_samples
         else:
             assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
-            num_valid_samples = int(num_data_samples * validation_split)
+            num_valid_samples = max(int(num_data_samples * validation_split), 1)
             num_train_samples = num_data_samples - num_valid_samples
-            train_dataset = dataset[num_train_samples]
-            valid_dataset = dataset[num_valid_samples]
+            train_dataset = dataset[num_valid_samples:]
+            valid_dataset = dataset[:num_valid_samples]

         logger.info('Tokenizing and processing dataset')
-        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
+        train_dataset = train_dataset.get_features(self.tokenizer,
+                                                   return_tensors='tf' if is_tf_available() else 'pt')
+        valid_dataset = valid_dataset.get_features(self.tokenizer,
+                                                   return_tensors='tf' if is_tf_available() else 'pt')
+        return train_dataset, valid_dataset

-    def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
-        if self.framework == 'tf':
+    def compile(self, learning_rate=3e-5, adam_epsilon=1e-8, **kwargs):
+        if is_tf_available():
             logger.info('Preparing model')
             # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
-            opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)
+            opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=adam_epsilon)
             if USE_AMP:
                 # loss scaling is currently required when using mixed precision
                 opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
@@ -150,39 +157,37 @@ class TextClassificationPipeline(object):
         self.is_compiled = True

-    def fit(self, train_samples_text=None, train_samples_labels=None,
-            valid_samples_text=None, valid_samples_labels=None,
-            train_batch_size=None, valid_batch_size=None,
+    def fit(self, X=None, y=None,
+            validation_data=None,
             validation_split=0.1,
+            train_batch_size=None,
+            valid_batch_size=None,
             **kwargs):
-        # Generic compatibility with sklearn and Keras
-        if 'y' in kwargs and train_samples_labels is None:
-            train_samples_labels = kwargs.pop('y')
-        if 'X' in kwargs and train_samples_text is None:
-            train_samples_text = kwargs.pop('X')
-
         if not self.is_compiled:
             self.compile(**kwargs)

-        datasets = self.prepare_data(train_samples_text, train_samples_labels,
-                                     valid_samples_text, valid_samples_labels,
-                                     validation_split)
-        train_dataset, valid_dataset, num_train_samples, num_valid_samples = datasets
+        train_dataset, valid_dataset = self.prepare_data(X, y=y,
+                                                         validation_data=validation_data,
+                                                         validation_split=validation_split)
+        num_train_samples = len(train_dataset)
+        num_valid_samples = len(valid_dataset)
         train_steps = num_train_samples//train_batch_size
         valid_steps = num_valid_samples//valid_batch_size

-        if self.framework == 'tf':
+        if is_tf_available():
             # Prepare dataset as a tf.train_data.Dataset instance
             train_dataset = train_dataset.shuffle(128).batch(train_batch_size).repeat(-1)
             valid_dataset = valid_dataset.batch(valid_batch_size)

             logger.info('Training TF 2.0 model')
             history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
-                                     validation_data=valid_dataset, validation_steps=valid_steps, **kwargs)
+                                     validation_data=valid_dataset, validation_steps=valid_steps,
+                                     **kwargs)
         else:
             raise NotImplementedError

         self.is_trained = True
@@ -210,9 +215,11 @@ class TextClassificationPipeline(object):
         if not self.is_trained:
             logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")

-        inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)
+        inputs = self.tokenizer.batch_encode_plus(texts,
+                                                  add_special_tokens=True,
+                                                  return_tensors='tf' if is_tf_available() else 'pt')

-        if self.framework == 'tf':
+        if is_tf_available():
             # TODO trace model
             predictions = self.model(**inputs)[0]
         else:
...
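End to end, the reworked pipeline can be driven entirely from Python rather than only through the CLI. A minimal sketch under TensorFlow 2.0 (the model name, example texts, labels and output directory are made up; from_pretrained, fit and save_pretrained are the methods shown in the hunks above):

from transformers import TextClassificationPipeline

pipeline = TextClassificationPipeline.from_pretrained('bert-base-uncased')

# Texts and labels can be passed directly; prepare_data wraps them in a
# SingleSentenceClassificationProcessor and splits off a validation set
pipeline.fit(['this movie was great', 'terrible pacing', 'would watch again', 'fell asleep'],
             ['pos', 'neg', 'pos', 'neg'],
             validation_split=0.25,
             train_batch_size=2,
             valid_batch_size=2)

# Persist both the fine-tuned model and its tokenizer
pipeline.save_pretrained('./my-text-classifier')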