Unverified Commit 5aa8a278 authored by Myle Ott, committed by GitHub

Fix roberta checkpoint conversion script (#3642)

parent 11cc1e16
@@ -25,15 +25,8 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version
-from transformers.modeling_bert import (
-    BertConfig,
-    BertIntermediate,
-    BertLayer,
-    BertOutput,
-    BertSelfAttention,
-    BertSelfOutput,
-)
-from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
+from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification

 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
@@ -55,7 +48,7 @@ def convert_roberta_checkpoint_to_pytorch(
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
     roberta_sent_encoder = roberta.model.decoder.sentence_encoder
-    config = BertConfig(
+    config = RobertaConfig(
         vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
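
Note: the config-class change means the converted checkpoint is built (and later saved) with a RoBERTa configuration rather than a BERT one. A minimal sketch of what the constructed target model looks like under the new import, with hypothetical hyperparameters standing in for the values the script reads from the fairseq args:

from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM

# Hypothetical values standing in for roberta.args.* read from the fairseq checkpoint.
config = RobertaConfig(
    vocab_size=50265,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
)
model = RobertaForMaskedLM(config)
model.eval()  # the script then copies the fairseq encoder weights into this model layer by layer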
@@ -138,7 +131,7 @@ def convert_roberta_checkpoint_to_pytorch(
     model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
     model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
     model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-    model.lm_head.bias = roberta.model.decoder.lm_head.bias
+    model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias

     # Let's check that we get the same results.
     input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
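
Note: the lm_head fix copies the fairseq output bias onto model.lm_head.decoder.bias instead of model.lm_head.bias, presumably because that is the parameter the head's output projection applies; the "check that we get the same results" step below is what verifies the conversion end to end. A minimal sketch of that kind of parity check, assuming a local fairseq checkpoint directory and a hypothetical ./roberta-converted directory holding the converted weights:

import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from transformers.modeling_roberta import RobertaForMaskedLM

# Hypothetical paths -- substitute your own fairseq checkpoint and conversion output.
fairseq_roberta = FairseqRobertaModel.from_pretrained("./roberta.base")
fairseq_roberta.eval()  # disable dropout
hf_roberta = RobertaForMaskedLM.from_pretrained("./roberta-converted")
hf_roberta.eval()

sample_text = "Hello world!"
input_ids = fairseq_roberta.encode(sample_text).unsqueeze(0)  # batch of size 1

with torch.no_grad():
    hf_logits = hf_roberta(input_ids)[0]  # (1, seq_len, vocab_size)
    fairseq_logits = fairseq_roberta.model(input_ids)[0]

# With the bias copied to the right place, the two outputs should agree closely.
print("max abs diff:", (hf_logits - fairseq_logits).abs().max().item())
print("all close:", torch.allclose(hf_logits, fairseq_logits, atol=1e-3))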