"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "aee596d52363ac052c0cbea43536346f007d4db9"
Commit a171c2dd authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

added readme code for inference with GLUE finetuned model

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/820

Differential Revision: D16783469

fbshipit-source-id: d5af8ba6a6685608d67b72d584952b8e43eabf9f
parent 577e4fa7
...@@ -65,3 +65,35 @@ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calcul ...@@ -65,3 +65,35 @@ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calcul
b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`. b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.
c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search.
### Inference on GLUE task
After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
```python
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment