Commit 2d855973 authored by thomwolf, committed by Morgan Funtowicz

add pipeline - train

parent 72c36b9e
import os
from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers.commands import BaseTransformersCLICommand
@@ -14,8 +14,6 @@ else:
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
# TF training parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False
@@ -24,7 +22,7 @@ def train_command_factory(args: Namespace):
Factory function used to instantiate the training command from provided command line arguments.
:return: TrainCommand
"""
return TrainCommand(args)
class TrainCommand(BaseTransformersCLICommand):
@@ -38,50 +36,84 @@ class TrainCommand(BaseTransformersCLICommand):
"""
train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
train_parser.add_argument('--train_data', type=str, required=True,
help="path to train (and optionally evaluation) dataset as a csv with "
"tab separated labels and sentences.")
train_parser.add_argument('--column_label', type=int, default=0,
help='Column of the dataset csv file with example labels.')
train_parser.add_argument('--column_text', type=int, default=1,
help='Column of the dataset csv file with example texts.')
train_parser.add_argument('--column_id', type=int, default=2,
help='Column of the dataset csv file with example ids.')
train_parser.add_argument('--validation_data', type=str, default='',
help='path to validation dataset.')
train_parser.add_argument('--validation_split', type=float, default=0.1,
help="if validation dataset is not provided, fraction of train dataset "
"to use as validation dataset.")
train_parser.add_argument('--output', type=str, default='./',
help='path to save the trained model.')
train_parser.add_argument('--task', type=str, default='text_classification',
help='Task to train the model on.')
train_parser.add_argument('--model', type=str, default='bert-base-uncased',
help='Model\'s name or path to stored model.')
train_parser.add_argument('--train_batch_size', type=int, default=32,
help='Batch size for training.')
train_parser.add_argument('--valid_batch_size', type=int, default=64,
help='Batch size for validation.')
train_parser.add_argument('--learning_rate', type=float, default=3e-5,
help="Learning rate.")
train_parser.add_argument('--adam_epsilon', type=float, default=1e-08,
help="Epsilon for Adam optimizer.")
train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace):
self.logger = getLogger('transformers-cli/training')
self.framework = 'tf' if is_tf_available() else 'torch'
os.makedirs(args.output, exist_ok=True)
self.output = args.output
self.column_label = args.column_label
self.column_text = args.column_text
self.column_id = args.column_id
self.logger.info('Loading model {}'.format(args.model))
self.model_name = args.model
self.tokenizer = AutoTokenizer.from_pretrained(args.model)
if args.task == 'text_classification':
self.model = SequenceClassifModel.from_pretrained(args.model)
elif args.task == 'token_classification':
raise NotImplementedError
elif args.task == 'question_answering':
raise NotImplementedError
self.logger.info('Loading dataset from {}'.format(args.train_data))
dataset = SingleSentenceClassificationProcessor.create_from_csv(args.train_data)
num_data_samples = len(dataset)
if args.validation_data:
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
self.valid_dataset = SingleSentenceClassificationProcessor.create_from_csv(args.validation_data)
self.num_valid_samples = len(self.valid_dataset)
self.train_dataset = dataset
self.num_train_samples = num_data_samples
else:
assert 0.0 < args.validation_split < 1.0, "--validation_split should be between 0.0 and 1.0"
self.num_valid_samples = int(num_data_samples * args.validation_split)
self.num_train_samples = num_data_samples - self.num_valid_samples
self.train_dataset = dataset[:self.num_train_samples]
self.valid_dataset = dataset[self.num_train_samples:]
self.train_batch_size = args.train_batch_size
self.valid_batch_size = args.valid_batch_size
self.learning_rate = args.learning_rate
self.adam_epsilon = args.adam_epsilon
def run(self):
if self.framework == 'tf':
return self.run_tf()
return self.run_torch()
@@ -95,27 +127,28 @@ class TrainCommand(BaseTransformersCLICommand):
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
# Prepare dataset as a tf.data.Dataset instance
self.logger.info('Tokenizing and processing dataset')
train_dataset = self.train_dataset.get_features(self.tokenizer, return_tensors='tf')
valid_dataset = self.valid_dataset.get_features(self.tokenizer, return_tensors='tf')
train_dataset = train_dataset.shuffle(128).batch(self.train_batch_size).repeat(-1)
valid_dataset = valid_dataset.batch(self.valid_batch_size)
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
opt = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, epsilon=self.adam_epsilon)
if USE_AMP:
# loss scaling is currently required when using mixed precision
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
# Train and evaluate using tf.keras.Model.fit()
train_steps = self.num_train_samples//self.train_batch_size
valid_steps = self.num_valid_samples//self.valid_batch_size
self.logger.info('Training model')
history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
validation_data=valid_dataset, validation_steps=valid_steps)
# Save trained model
self.model.save_pretrained(self.output)
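A minimal sketch (not part of the commit) of driving the new TrainCommand programmatically with an argparse Namespace; the module path `transformers.commands.train`, the CSV path and the output directory are assumptions, and TensorFlow 2.0 is assumed installed so that run() dispatches to run_tf():

from argparse import Namespace
from transformers.commands.train import TrainCommand  # assumed module path for the file above

# train.csv is a placeholder: a tab-separated file with label and sentence columns
args = Namespace(train_data='./data/train.csv', validation_data='', validation_split=0.1,
                 column_label=0, column_text=1, column_id=2,
                 output='./trained_model/', task='text_classification', model='bert-base-uncased',
                 train_batch_size=32, valid_batch_size=64,
                 learning_rate=3e-5, adam_epsilon=1e-08)
TrainCommand(args).run()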
@@ -18,6 +18,11 @@ import csv
import sys
import copy
import json
import logging
from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__)
class InputExample(object):
"""
@@ -64,7 +69,7 @@ class InputFeatures(object):
label: Label corresponding to the input
"""
def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
@@ -86,34 +91,6 @@ class InputFeatures(object):
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise NotImplementedError()
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def tfds_map(self, example):
"""Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
This method converts examples to the correct format."""
if len(self.get_labels()) > 1:
example.label = self.get_labels()[int(example.label)]
return example
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
@@ -129,15 +106,11 @@ class DataProcessor(object):
class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples
self.mode = mode
self.verbose = verbose
def __len__(self):
return len(self.examples)
@@ -148,30 +121,40 @@ class SingleSentenceClassificationProcessor(DataProcessor):
examples=self.examples[idx])
return self.examples[idx]
@classmethod
def create_from_csv(cls, file_name, **kwargs):
processor = cls(**kwargs)
processor.add_examples_from_csv(file_name)
return processor
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
overwrite_labels=False, overwrite_examples=False):
lines = self._read_tsv(file_name)
texts = []
labels = []
ids = []
for (i, line) in enumerate(lines):
texts.append(line[column_text])
labels.append(line[column_label])
if column_id is not None:
ids.append(line[column_id])
else:
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid)
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
def add_examples(self, texts, labels, ids=None, overwrite_labels=False, overwrite_examples=False):
if ids is None:
ids = [None] * len(texts)
assert len(texts) == len(labels)
assert len(texts) == len(ids)
examples = []
added_labels = set()
for (text, label, guid) in zip(texts, labels, ids):
added_labels.add(label)
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
# Update examples
if overwrite_examples:
@@ -187,19 +170,23 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return self.examples
@classmethod
def create_from_examples(cls, texts, labels, **kwargs):
processor = cls(**kwargs)
processor.add_examples(texts, labels)
return processor
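# Example (not part of this commit): the two classmethods above give two ways to build a
# processor; the texts, labels and csv path below are illustrative placeholders only.
#
#   processor = SingleSentenceClassificationProcessor.create_from_examples(
#       ["a great movie", "a terrible movie"], ["pos", "neg"])
#   assert len(processor) == 2
#
#   processor = SingleSentenceClassificationProcessor.create_from_csv('./data/train.csv')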
def get_features(self,
tokenizer,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
return_tensors=None):
"""
Convert examples into a list of ``InputFeatures``
Args:
tokenizer: Instance of a tokenizer that will tokenize the examples
max_length: Maximum example length
task: GLUE task
@@ -207,7 +194,6 @@ def convert_examples_to_features(examples, tokenizer,
output_mode: String indicating the output mode. Either ``regression`` or ``classification``
pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
pad_token: Padding token
mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
actual values)
@@ -218,92 +204,95 @@ def convert_examples_to_features(examples, tokenizer,
a list of task-specific ``InputFeatures`` which can be fed to the model.
"""
if max_length is None:
# default to the tokenizer's maximum length when no max_length is given
max_length = tokenizer.max_len
label_map = {label: i for i, label in enumerate(self.labels)}
all_input_ids = []
for (ex_index, example) in enumerate(self.examples):
if ex_index % 10000 == 0:
logger.info("Tokenizing example %d", ex_index)
input_ids = tokenizer.encode(
example.text_a,
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
)
all_input_ids.append(input_ids)
batch_length = max(len(input_ids) for input_ids in all_input_ids)
features = []
for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
if ex_index % 10000 == 0:
logger.info("Writing example %d", ex_index)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = batch_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
else:
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
if self.mode == "classification":
label = label_map[example.label]
elif self.mode == "regression":
label = float(example.label)
else:
raise ValueError(self.mode)
if ex_index < 5 and self.verbose:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("label: %s (id = %d)" % (example.label, label))
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
label=label))
if return_tensors is None:
return features
elif return_tensors == 'tf':
if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf
def gen():
for ex in features:
yield ({'input_ids': ex.input_ids,
'attention_mask': ex.attention_mask},
ex.label)
dataset = tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None])},
tf.TensorShape([])))
return dataset
elif return_tensors == 'pt':
if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch
from torch.utils.data import TensorDataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
if self.mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif self.mode == "regression":
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
return dataset
else:
raise ValueError("return_tensors should be one of 'tf' or 'pt'")
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Pipeline class: Tokenizer + Model. """
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import logging
from .modeling_auto import (AutoModel, AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelWithLMHead)
from .tokenization_auto import AutoTokenizer
from .file_utils import add_start_docstrings, is_tf_available, is_torch_available
from .data.processors import SingleSentenceClassificationProcessor
if is_tf_available():
import tensorflow as tf
if is_torch_available():
import torch
logger = logging.getLogger(__name__)
# TF training parameters
USE_XLA = False
USE_AMP = False
class TextClassificationPipeline(object):
r"""
:class:`~transformers.TextClassificationPipeline` is a class encapsulating a pretrained model and
its tokenizer and will be instantiated as one of the base model classes of the library
when created with the `Pipeline.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
"""
def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
self.model = model
if is_tf_available():
self.framework = 'tf'
elif is_torch_available():
self.framework = 'pt'
else:
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
self.is_compiled = False
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Params:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
state_dict: (`optional`) dict:
an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
Examples::
model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
# Extract tokenizer and model arguments
tokenizer_kwargs = {}
for key in list(kwargs):
if key.startswith('tokenizer_'):
# Specific to the tokenizer
tokenizer_kwargs[key.replace('tokenizer_', '')] = kwargs.pop(key)
elif not key.startswith('model_'):
# used for both the tokenizer and the model
tokenizer_kwargs[key] = kwargs[key]
model_kwargs = {}
for key in list(kwargs):
if key.startswith('model_'):
# Specific to the model
model_kwargs[key.replace('model_', '')] = kwargs.pop(key)
elif not key.startswith('tokenizer_'):
# used for both the tokenizer and the model
model_kwargs[key] = kwargs[key]
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **tokenizer_kwargs)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
return cls(tokenizer, model)
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
return
self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
def compile(self, learning_rate=3e-5, epsilon=1e-8):
if self.framework == 'tf':
logger.info('Preparing model')
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)
if USE_AMP:
# loss scaling is currently required when using mixed precision
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
else:
raise NotImplementedError
self.is_compiled = True
def prepare_data(self, train_samples_text, train_samples_labels,
valid_samples_text=None, valid_samples_labels=None,
validation_split=0.1):
dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
train_samples_labels)
num_data_samples = len(dataset)
if valid_samples_text is not None and valid_samples_labels is not None:
valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
valid_samples_labels)
num_valid_samples = len(valid_dataset)
train_dataset = dataset
num_train_samples = num_data_samples
else:
assert 0.0 < validation_split < 1.0, "validation_split should be between 0.0 and 1.0"
num_valid_samples = int(num_data_samples * validation_split)
num_train_samples = num_data_samples - num_valid_samples
train_dataset = dataset[:num_train_samples]
valid_dataset = dataset[num_train_samples:]
logger.info('Tokenizing and processing dataset')
train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
return train_dataset, valid_dataset, num_train_samples, num_valid_samples
def fit(self, train_samples_text, train_samples_labels,
valid_samples_text=None, valid_samples_labels=None,
train_batch_size=None, valid_batch_size=None,
validation_split=0.1,
**kwargs):
if not self.is_compiled:
self.compile()
datasets = self.prepare_data(train_samples_text, train_samples_labels,
valid_samples_text, valid_samples_labels,
validation_split)
train_dataset, valid_dataset, num_train_samples, num_valid_samples = datasets
train_steps = num_train_samples//train_batch_size
valid_steps = num_valid_samples//valid_batch_size
if self.framework == 'tf':
# Prepare dataset as a tf.data.Dataset instance
train_dataset = train_dataset.shuffle(128).batch(train_batch_size).repeat(-1)
valid_dataset = valid_dataset.batch(valid_batch_size)
logger.info('Training TF 2.0 model')
history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
validation_data=valid_dataset, validation_steps=valid_steps, **kwargs)
else:
raise NotImplementedError
def __call__(self, text):
inputs = self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors=self.framework)
if self.framework == 'tf':
# TODO trace model
predictions = self.model(**inputs)[0]
else:
with torch.no_grad():
predictions = self.model(**inputs)[0]
return predictions.numpy().tolist()
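A minimal sketch (not part of the commit) of the TextClassificationPipeline API defined above; the import path, texts, labels and save directory are placeholders, and TF 2.0 is assumed installed since compile() and fit() are TF-only at this point:

import os
from transformers.pipelines import TextClassificationPipeline  # assumed import path for the file above

pipeline = TextClassificationPipeline.from_pretrained('bert-base-uncased')
pipeline.fit(["a great movie", "a terrible movie", "not bad at all", "utterly boring"],
             [1, 0, 1, 0],
             train_batch_size=2, valid_batch_size=2, validation_split=0.5)
scores = pipeline("an enjoyable watch")  # class logits for the input text

os.makedirs('./my_text_classifier/', exist_ok=True)  # save_pretrained expects an existing directory
pipeline.save_pretrained('./my_text_classifier/')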