"git@developer.sourcefind.cn:change/sglang.git" did not exist on "066f4ec91f7becd1d490745c856c4a57ede9d1ca"
Commit 17fcc72a authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add RoBERTa README

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/778

Differential Revision: D16525447

Pulled By: myleott

fbshipit-source-id: e721e3a10e243a2408a04f89f06b5adbbe2fdff2
parent 8835d93c
# RoBERTa: A Robustly Optimized BERT Pretraining Approach
*Pre-print coming 7/28*
## Introduction
**RoBERTa** iterates on BERT's pretraining procedure, including training the model longer, with bigger batches over more data; removing the next sentence prediction objective; training on longer sequences; and dynamically changing the masking pattern applied to the training data. See the associated paper for more details.
## Pre-trained models
Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
## Example usage (torch.hub)
##### Load RoBERTa:
```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
```
##### Apply Byte-Pair Encoding (BPE) to input text:
```
>>> tokens = roberta.encode('Hello world!')
>>> tokens
tensor([ 0, 31414, 232, 328, 2])
```
##### Extract features from RoBERTa:
```
>>> features = roberta.extract_features(tokens)
>>> features.size()
torch.Size([1, 5, 1024])
```
##### Use RoBERTa for sentence-pair classification tasks:
```
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli') # already finetuned
>>> roberta.eval() # disable dropout for evaluation
>>> tokens = roberta.encode(
... 'Roberta is a heavily optimized version of BERT.',
... 'Roberta is not very optimized.'
... )
>>> roberta.predict('mnli', tokens).argmax()
tensor(0) # contradiction
>>> tokens = roberta.encode(
... 'Roberta is a heavily optimized version of BERT.',
... 'Roberta is based on BERT.'
... )
>>> roberta.predict('mnli', tokens).argmax()
tensor(2) # entailment
```
##### Register a new (randomly initialized) classification head:
```
>>> roberta.register_classification_head('new_task', num_classes=3)
>>> roberta.predict('new_task', tokens)
tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
```
##### Using the GPU:
```
>>> roberta.cuda()
>>> roberta.predict('new_task', tokens)
tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
## Results
##### Results on GLUE tasks (dev set, single model, single-task finetuning)
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
##### Results on SQuAD (dev set)
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
##### Results on Reading Comprehension (RACE, test set)
Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3
## Evaluating the `roberta.large.mnli` model
Example python code snippet to evaluate accuracy on the MNLI dev_matched set.
```
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('mnli', tokens).argmax().item()
prediction_label = label_map[prediction]
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9060
```
## Finetuning on GLUE tasks
A more detailed tutorial is coming soon.
## Pretraining using your own data
You can use the [`masked_lm` task](/fairseq/tasks/masked_lm.py) to pretrain RoBERTa from scratch, or to continue pretraining RoBERTa starting from one of the released checkpoints.
Data should be preprocessed following the [language modeling example](/examples/language_model).
A more detailed tutorial is coming soon.
## Citation
```bibtex
@article{liu2019roberta,
title = {RoBERTa: A Robustly Optimized BERT Pretraining Approach},
author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
Luke Zettlemoyer and Veselin Stoyanov},
year = {2019},
}
```
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.models import (
FairseqDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
class RobertaHubInterface(nn.Module):
"""A simple PyTorch Hub interface to RoBERTa.
Load RoBERTa::
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
Apply Byte-Pair Encoding (BPE) to input text::
>>> tokens = roberta.encode('Hello world!')
>>> tokens
tensor([ 0, 31414, 232, 328, 2])
Extract features from RoBERTa::
>>> features = roberta.extract_features(tokens)
>>> features.size()
torch.Size([1, 5, 1024])
Use RoBERTa for sentence-pair classification tasks::
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli') # already finetuned
>>> roberta.eval() # disable dropout for evaluation
>>> tokens = roberta.encode(
... 'Roberta is a heavily optimized version of BERT.',
... 'Roberta is not very optimized.'
... )
>>> roberta.predict('mnli', tokens).argmax()
tensor(0) # contradiction
>>> tokens = roberta.encode(
... 'Roberta is a heavily optimized version of BERT.',
... 'Roberta is based on BERT.'
... )
>>> roberta.predict('mnli', tokens).argmax()
tensor(2) # entailment
Register a new (randomly initialized) classification head::
>>> roberta.register_classification_head('new_task', num_classes=3)
>>> roberta.predict('new_task', tokens)
tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
Using the GPU::
>>> roberta.cuda()
>>> roberta.predict('new_task', tokens)
tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
"""
def __init__(self, args, task, model):
super().__init__()
self.args = args
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
# this is useful for determining the device
self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
for s in addl_sentences:
bpe_sentence += ' </s> ' + self.bpe.encode(s)
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
return tokens.long()
def extract_features(self, tokens: torch.LongTensor) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
features, _ = self.model(tokens.to(device=self.device), features_only=True)
return features
def register_classification_head(
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
):
self.model.register_classification_head(
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
)
def predict(self, head: str, tokens: torch.LongTensor):
features = self.extract_features(tokens)
logits = self.model.classification_heads[head](features)
return F.log_softmax(logits, dim=-1)
@register_model('roberta')
class RobertaModel(FairseqLanguageModel):
@classmethod
def hub_models(cls):
return {
'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
}
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
# We follow BERT's random weight initialization
self.apply(init_bert_params)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--encoder-layers', type=int, metavar='L',
help='num encoder layers')
parser.add_argument('--encoder-embed-dim', type=int, metavar='H',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='F',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-attention-heads', type=int, metavar='A',
help='num encoder attention heads')
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--pooler-activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use for pooler layer')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN')
parser.add_argument('--pooler-dropout', type=float, metavar='D',
help='dropout probability in the masked_lm pooler layers')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture(args)
if not hasattr(args, 'max_positions'):
args.max_positions = args.tokens_per_sample
encoder = RobertaEncoder(args, task.source_dictionary)
return cls(args, encoder)
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
"""Register a classification head."""
self.classification_heads[name] = RobertaClassificationHead(
self.args.encoder_embed_dim,
inner_dim or self.args.encoder_embed_dim,
num_classes,
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
@property
def supported_targets(self):
return {'self'}
@classmethod
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe='gpt2',
**kwargs,
)
return RobertaHubInterface(x['args'], x['task'], x['models'][0])
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
# recreate any classification heads present in the state dict
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
continue
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[
prefix + 'classification_heads.' + head_name + '.out_proj.weight'
].size(0)
inner_dim = state_dict[
prefix + 'classification_heads.' + head_name + '.dense.weight'
].size(0)
self.register_classification_head(head_name, num_classes, inner_dim)
class RobertaLMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, **kwargs):
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class RobertaEncoder(FairseqDecoder):
"""RoBERTa encoder.
Implements the :class:`~fairseq.models.FairseqDecoder` interface required
by :class:`~fairseq.models.FairseqLanguageModel`.
"""
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args
self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=dictionary.pad(),
vocab_size=len(dictionary),
num_encoder_layers=args.encoder_layers,
embedding_dim=args.encoder_embed_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
max_seq_len=args.max_positions,
num_segments=0,
encoder_normalize_before=True,
apply_bert_init=True,
activation_fn=args.activation_fn,
)
self.lm_head = RobertaLMHead(
embed_dim=args.encoder_embed_dim,
output_dim=len(dictionary),
activation_fn=args.activation_fn,
weight=self.sentence_encoder.embed_tokens.weight,
)
def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
features_only (bool, optional): skip LM head and just return
features. If True, the output will be of shape
`(batch, src_len, embed_dim)`.
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states.
"""
x, extra = self.extract_features(src_tokens, return_all_hiddens)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
inner_states, _ = self.sentence_encoder(
src_tokens, last_state_only=not return_all_hiddens,
)
features = inner_states[-1]
return features, {'inner_states': inner_states if return_all_hiddens else None}
def output_layer(self, features, **unused):
return self.lm_head(features)
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.args.max_positions
@register_model_architecture('roberta', 'roberta')
def base_architecture(args):
args.encoder_layers = getattr(args, 'encoder_layers', 12)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
@register_model_architecture('roberta', 'roberta_base')
def roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture('roberta', 'roberta_large')
def roberta_large_architecture(args):
args.encoder_layers = getattr(args, 'encoder_layers', 24)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
base_architecture(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment