Commit 7c169756 authored by thomwolf, committed by Morgan Funtowicz

compatibility with sklearn and keras

parent b81ab431
@@ -83,7 +83,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.logger.info('Loading model {}'.format(args.model_name))
         self.model_name = args.model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        self.pipeline = AutoTokenizer.from_pretrained(args.model_name)
         if args.task == 'text_classification':
             self.model = SequenceClassifModel.from_pretrained(args.model_name)
         elif args.task == 'token_classification':
...
@@ -222,7 +222,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         batch_length = max(len(input_ids) for input_ids in all_input_ids)
         features = []
-        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, examples)):
+        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
             if ex_index % 10000 == 0:
                 logger.info("Writing example %d", ex_index)
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
...
@@ -109,25 +109,9 @@ class TextClassificationPipeline(object):
         self.tokenizer.save_pretrained(save_directory)

-    def compile(self, learning_rate=3e-5, epsilon=1e-8):
-        if self.framework == 'tf':
-            logger.info('Preparing model')
-            # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
-            opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)
-            if USE_AMP:
-                # loss scaling is currently required when using mixed precision
-                opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
-            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-            metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
-            self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
-        else:
-            raise NotImplementedError
-        self.is_compiled = True
-
     def prepare_data(self, train_samples_text, train_samples_labels,
                      valid_samples_text=None, valid_samples_labels=None,
-                     validation_split=0.1):
+                     validation_split=0.1, **kwargs):
         dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
                                                                              train_samples_labels)
         num_data_samples = len(dataset)
@@ -138,7 +122,7 @@ class TextClassificationPipeline(object):
             train_dataset = dataset
             num_train_samples = num_data_samples
         else:
-            assert 0.0 < validation_split < 1.0, "validation_split should be between 0.0 and 1.0"
+            assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
             num_valid_samples = int(num_data_samples * validation_split)
             num_train_samples = num_data_samples - num_valid_samples
             train_dataset = dataset[num_train_samples]
@@ -150,14 +134,36 @@ class TextClassificationPipeline(object):
         return train_dataset, valid_dataset, num_train_samples, num_valid_samples

-    def fit(self, train_samples_text, train_samples_labels,
+    def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
+        if self.framework == 'tf':
+            logger.info('Preparing model')
+            # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
+            opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)
+            if USE_AMP:
+                # loss scaling is currently required when using mixed precision
+                opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+            metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
+            self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
+        else:
+            raise NotImplementedError
+        self.is_compiled = True
+
+    def fit(self, train_samples_text=None, train_samples_labels=None,
             valid_samples_text=None, valid_samples_labels=None,
             train_batch_size=None, valid_batch_size=None,
             validation_split=0.1,
             **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'y' in kwargs and train_samples_labels is None:
+            train_samples_labels = kwargs.pop('y')
+        if 'X' in kwargs and train_samples_text is None:
+            train_samples_text = kwargs.pop('X')
+
         if not self.is_compiled:
-            self.compile()
+            self.compile(**kwargs)
         datasets = self.prepare_data(train_samples_text, train_samples_labels,
                                      valid_samples_text, valid_samples_labels,
@@ -180,11 +186,32 @@ class TextClassificationPipeline(object):
         self.is_trained = True

-    def __call__(self, text):
+    def fit_transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        self.fit(*texts, **kwargs)
+        return self(*texts, **kwargs)
+
+    def transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def predict(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def __call__(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'X' in kwargs and not texts:
+            texts = kwargs.pop('X')
+
         if not self.is_trained:
             logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")
-        inputs = self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors=self.framework)
+        inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)
         if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(**inputs)[0]
...
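A minimal usage sketch of the sklearn-style entry points this commit adds. It assumes `pipeline` is an already-constructed TextClassificationPipeline (its construction is not part of this diff), and the sample texts and labels are purely illustrative.

```python
# Illustrative only: `pipeline` is assumed to be a TextClassificationPipeline
# instance created elsewhere; construction details are outside this diff.
train_texts = ["great movie", "terrible plot"]
train_labels = [1, 0]

# fit() now maps sklearn/Keras-style X/y keyword arguments onto
# train_samples_text / train_samples_labels before preparing the data.
pipeline.fit(X=train_texts, y=train_labels)

# predict() and transform() are thin aliases for __call__(), which now takes a
# batch of texts (via *texts or X=...) and encodes it with batch_encode_plus.
predictions = pipeline.predict(X=["an unseen review"])
```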