Commit 7c9f8f93 authored by thomwolf's avatar thomwolf
Browse files

fix tests

parent d6dde438
import tensorflow as tf
import tensorflow_datasets
from pytorch_transformers import BertTokenizer, BertForSequenceClassification, TFBertForSequenceClassification, glue_convert_examples_to_features
from transformers import *
# Load tokenizer, model, dataset
# Load dataset, tokenizer, model from pretrained model/vocabulary
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
tf_model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
dataset = tensorflow_datasets.load("glue/mrpc")
dataset = tensorflow_datasets.load('glue/mrpc')
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
# Prepare dataset for GLUE
train_dataset = glue_convert_examples_to_features(dataset['train'], tokenizer, task='mrpc', max_length=128)
valid_dataset = glue_convert_examples_to_features(dataset['validation'], tokenizer, task='mrpc', max_length=128)
# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(dataset['train'], tokenizer, task='mrpc')
valid_dataset = glue_convert_examples_to_features(dataset['validation'], tokenizer, task='mrpc')
train_dataset = train_dataset.shuffle(100).batch(32).repeat(3)
valid_dataset = valid_dataset.batch(64)
# Compile tf.keras model for training
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(2e-5, 345, end_learning_rate=0)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
tf_model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])
model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])
# Train and evaluate using tf.keras.Model.fit()
tf_model.fit(train_dataset, epochs=3, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
model.fit(train_dataset, epochs=3, steps_per_epoch=115,
validation_data=valid_dataset, validation_steps=7)
# Save the model and load it in PyTorch
tf_model.save_pretrained('./runs/')
pt_model = BertForSequenceClassification.from_pretrained('./runs/', from_tf=True)
# Save the TensorFlow model and load it in PyTorch
model.save_pretrained('./save/')
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
# Quickly inspect a few predictions
inputs = tokenizer.encode_plus("I said the company is doing great", "The company has good results", add_special_tokens=True, return_tensors='pt')
pred = pt_model(torch.tensor(tokens))
# Quickly inspect a few predictions - MRPC is a paraphrasing task
inputs = tokenizer.encode_plus("The company is doing great",
"The company has good results",
add_special_tokens=True,
return_tensors='pt')
pred = pytorch_model(**inputs)
print("Paraphrase" if pred.argmax().item() == 0 else "Not paraphrase")
......@@ -199,13 +199,13 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='query')
self.key = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='key')
self.value = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='value')
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
......@@ -260,7 +260,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertSelfOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
......@@ -296,7 +296,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertIntermediate, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
......@@ -313,7 +313,7 @@ class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
......@@ -383,7 +383,7 @@ class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPooler, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
activation='tanh',
name='dense')
......@@ -399,7 +399,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPredictionHeadTransform, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act]
......@@ -452,7 +452,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertNSPHead, self).__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(2,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='seq_relationship')
def call(self, pooled_output):
......@@ -843,7 +843,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
......@@ -895,7 +895,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(1,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
......@@ -974,7 +974,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
......@@ -1026,7 +1026,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name='qa_outputs')
def call(self, inputs, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment