Commit 1a2b40cb authored by Lysandre's avatar Lysandre
Browse files

run_tf_glue MRPC evaluation only for MRPC

parent be36cf92
...@@ -71,20 +71,21 @@ history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps, ...@@ -71,20 +71,21 @@ history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
os.makedirs('./save/', exist_ok=True) os.makedirs('./save/', exist_ok=True)
model.save_pretrained('./save/') model.save_pretrained('./save/')
# Load the TensorFlow model in PyTorch for inspection if TASK == "mrpc":
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True) # Load the TensorFlow model in PyTorch for inspection
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
sentence_0 = 'This research was consistent with his findings.' # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
sentence_1 = 'His findings were compatible with this research.' sentence_0 = 'This research was consistent with his findings.'
sentence_2 = 'His findings were not compatible with this research.' sentence_1 = 'His findings were compatible with this research.'
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt') sentence_2 = 'His findings were not compatible with this research.'
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt') inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
del inputs_1["special_tokens_mask"]
del inputs_2["special_tokens_mask"] del inputs_1["special_tokens_mask"]
del inputs_2["special_tokens_mask"]
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
pred_2 = pytorch_model(**inputs_2)[0].argmax().item() pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0') pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0') print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment