Commit ce7f044b authored by Myle Ott, committed by Facebook Github Bot

Add instructions to load RoBERTa models on PyTorch 1.0

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/921

Differential Revision: D16541025

Pulled By: myleott

fbshipit-source-id: bb78d30fe285da2adfc7c4e5897ee01fa413b2e4
parent 8d036c2f
@@ -14,15 +14,47 @@ Model | Description | # params | Download
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
## Results
##### Results on GLUE tasks (dev set, single model, single-task finetuning)
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
##### Results on SQuAD (dev set)
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
##### Results on Reading Comprehension (RACE, test set)
Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3
## Example usage
##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval()  # disable dropout (or leave in train mode to finetune)
```
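The other checkpoints in the table above are exposed through the same hub entry point; a minimal sketch, assuming the MNLI-finetuned model is registered under the name shown in the table:
```
>>> # Sketch: load the MNLI-finetuned checkpoint the same way (name taken from the table above)
>>> roberta_mnli = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
>>> roberta_mnli.eval()  # disable dropout for evaluation
```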
##### Load RoBERTa (for PyTorch 1.0):
```
$ wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
$ tar -xzvf roberta.large.tar.gz
>>> from fairseq.models.roberta import RobertaModel
>>> roberta = RobertaModel.from_pretrained('/path/to/roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```
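`from_pretrained` can also be given the checkpoint file name explicitly; a minimal sketch, assuming the extracted archive contains the default `model.pt` (the path and file name here are placeholders):
```
>>> from fairseq.models.roberta import RobertaModel
>>> # Sketch: pass the checkpoint name explicitly; 'model.pt' is assumed to be the file
>>> # inside the extracted roberta.large directory
>>> roberta = RobertaModel.from_pretrained('/path/to/roberta.large', checkpoint_file='model.pt')
>>> roberta.eval()  # disable dropout (or leave in train mode to finetune)
```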
##### Apply Byte-Pair Encoding (BPE) to input text:
```
>>> tokens = roberta.encode('Hello world!')
@@ -80,29 +112,7 @@ tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
##### Evaluating the `roberta.large.mnli` model
Example Python code snippet to evaluate accuracy on the MNLI dev_matched set.
```
...
```
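The snippet itself is elided in this diff. Below is a minimal sketch of such an evaluation loop; it assumes a GLUE-formatted `glue_data/MNLI/dev_matched.tsv` (tab-separated, with the two sentences in columns 8 and 9 and the gold label in the last column) and a label order of 0: contradiction, 1: neutral, 2: entailment. Both the file layout and the label order are assumptions here, not confirmed by this diff.
```
# Sketch: assumes `roberta` is the roberta.large.mnli model loaded as shown above
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}  # assumed label order
ncorrect, nsamples = 0, 0
roberta.cuda()   # optional: move the model to GPU
roberta.eval()   # disable dropout
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()  # skip the header row
    for line in fin:
        cols = line.strip().split('\t')
        sent1, sent2, target = cols[8], cols[9], cols[-1]  # assumed GLUE column layout
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('mnli', tokens).argmax().item()
        ncorrect += int(label_map[prediction] == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect) / float(nsamples))
```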