Commit a171c2dd authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

added readme code for inference with GLUE finetuned model

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/820

Differential Revision: D16783469

fbshipit-source-id: d5af8ba6a6685608d67b72d584952b8e43eabf9f
parent 577e4fa7
......@@ -65,3 +65,35 @@ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calcul
b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.
c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search.
### Inference on GLUE task
After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
```python
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment