Commit ce7f044b authored by Myle Ott, committed by Facebook Github Bot

Add instructions to load RoBERTa models on PyTorch 1.0

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/921

Differential Revision: D16541025

Pulled By: myleott

fbshipit-source-id: bb78d30fe285da2adfc7c4e5897ee01fa413b2e4
parent 8d036c2f
......@@ -14,15 +14,47 @@ Model | Description | # params | Download
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
## Example usage (torch.hub)
## Results
##### Results on GLUE tasks (dev set, single model, single-task finetuning)
##### Load RoBERTa:
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
##### Results on SQuAD (dev set)
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
##### Results on Reading Comprehension (RACE, test set)
Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3
## Example usage
##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```
##### Load RoBERTa (for PyTorch 1.0):
```
$ wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
$ tar -xzvf roberta.large.tar.gz
>>> from fairseq.models.roberta import RobertaModel
>>> roberta = RobertaModel.from_pretrained('/path/to/roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```
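The two loading paths above differ only in which PyTorch version is installed, so the choice can be made at runtime. The following is a minimal sketch, not part of fairseq: the helper names `parse_major_minor`, `can_use_torch_hub`, and `load_roberta` are hypothetical, and the sketch assumes the archive has already been downloaded and extracted to `model_dir` for the PyTorch 1.0 path.

```python
import re


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '1.0.1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def can_use_torch_hub(version):
    """Return True when the given PyTorch version is >= 1.1."""
    return parse_major_minor(version) >= (1, 1)


def load_roberta(model_dir, hub_name="roberta.large"):
    """Load RoBERTa via torch.hub if available, else from a local directory.

    Hypothetical helper: `model_dir` is assumed to be the extracted
    roberta.large archive, as in the PyTorch 1.0 instructions.
    """
    import torch  # imported lazily so the version helpers work without it

    if can_use_torch_hub(torch.__version__):
        return torch.hub.load("pytorch/fairseq", hub_name)

    from fairseq.models.roberta import RobertaModel

    return RobertaModel.from_pretrained(model_dir)
```

Comparing `(major, minor)` tuples rather than raw strings avoids treating a version such as `1.10` as older than `1.9`. As with the snippets above, call `roberta.eval()` on the returned model to disable dropout.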
##### Apply Byte-Pair Encoding (BPE) to input text:
```
>>> tokens = roberta.encode('Hello world!')
......@@ -80,29 +112,7 @@ tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
## Results
##### Results on GLUE tasks (dev set, single model, single-task finetuning)
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
##### Results on SQuAD (dev set)
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
##### Results on Reading Comprehension (RACE, test set)
Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3
## Evaluating the `roberta.large.mnli` model
##### Evaluating the `roberta.large.mnli` model
Example python code snippet to evaluate accuracy on the MNLI dev_matched set.
```
......