Commit ea636440
Authored Dec 17, 2019 by Julien Chaumond

[roberta.conversion] Do not hardcode vocab size
and support for fairseq 0.9+

Parent a4df2e01
Showing 1 changed file with 8 additions and 2 deletions

transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py  (+8 −2)
@@ -22,6 +22,12 @@ import numpy as np
 import torch
 
+import pathlib
+import fairseq
+from packaging import version
+if version.parse(fairseq.__version__) < version.parse("0.9.0"):
+    raise Exception("requires fairseq >= 0.9.0")
+
 from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from transformers.modeling_bert import (BertConfig, BertEncoder,
...
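The added import block requires fairseq 0.9.0 or newer and compares versions with `packaging.version`, which orders releases correctly where a plain string comparison would not. A minimal sketch of the same kind of gate, with a hypothetical `require_min_fairseq` helper that is not part of the script:

# Version-gate sketch; `require_min_fairseq` is a hypothetical helper,
# not part of the conversion script. Only `packaging` is assumed installed.
from packaging import version

def require_min_fairseq(installed: str, minimum: str = "0.9.0") -> None:
    # version.parse gives proper release ordering, e.g. "0.10.0" > "0.9.0",
    # which a lexical string comparison would get wrong.
    if version.parse(installed) < version.parse(minimum):
        raise Exception("requires fairseq >= %s, found %s" % (minimum, installed))

require_min_fairseq("0.9.0")    # passes silently
# require_min_fairseq("0.8.0")  # would raise Exception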
@@ -46,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     """
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
+    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     config = BertConfig(
-        vocab_size=50265,
+        vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
         num_attention_heads=roberta.args.encoder_attention_heads,
...
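This hunk is the fix named in the commit title: instead of hardcoding `vocab_size=50265`, the config now takes the vocabulary size from the checkpoint's own embedding matrix, so RoBERTa checkpoints trained with extra tokens convert without editing the script. A rough illustration with a standalone `nn.Embedding` standing in for fairseq's `sentence_encoder.embed_tokens` (sizes made up):

# Standalone illustration; the real script reads the size from
# roberta.model.decoder.sentence_encoder.embed_tokens as shown above.
import torch.nn as nn

embed_tokens = nn.Embedding(num_embeddings=50269, embedding_dim=768)  # e.g. 50265 + 4 added tokens

vocab_size = embed_tokens.num_embeddings  # whatever the checkpoint really holds
print(vocab_size)  # 50269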
@@ -65,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 
     # Now let's copy all the weights.
     # Embeddings
-    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
     model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
     model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight)  # just zero them out b/c RoBERTa doesn't use them.
...
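The last hunk only drops the now-duplicated `roberta_sent_encoder` assignment (it moved up next to the config) and keeps the existing trick of zeroing the token type embeddings, since RoBERTa has no segment embeddings. A small sketch of that trick in isolation, with illustrative sizes rather than the script's real modules:

# Isolated sketch of zeroing a token-type embedding table so it contributes
# nothing to the embedding sum; sizes are illustrative, not from the script.
import torch
import torch.nn as nn

token_type_embeddings = nn.Embedding(num_embeddings=1, embedding_dim=768)
token_type_embeddings.weight.data = torch.zeros_like(token_type_embeddings.weight)

segment_ids = torch.zeros(2, 5, dtype=torch.long)           # all tokens in segment 0
assert token_type_embeddings(segment_ids).abs().sum() == 0  # adds nothing downstream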