chenpangpang / transformers · Commits

Commit 9d060314 authored Aug 08, 2019 by Julien Chaumond

[RoBERTa] RobertaForSequenceClassification + conversion

parent d2cc6b10
Showing 3 changed files with 140 additions and 11 deletions
pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py   +25 -11
pytorch_transformers/modeling_roberta.py                        +57 -0
pytorch_transformers/tests/modeling_roberta_test.py             +58 -0
pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py
@@ -30,6 +30,7 @@ from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
                                                 BertSelfOutput)
 from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
                                                     RobertaForMaskedLM,
+                                                    RobertaForSequenceClassification,
                                                     RobertaModel)

 logging.basicConfig(level=logging.INFO)
@@ -38,7 +39,7 @@ logger = logging.getLogger(__name__)
 SAMPLE_TEXT = 'Hello world! cécé herlolip'


-def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path):
+def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
     """
     Copy/paste/tweak roberta's weights to our BERT structure.
     """
@@ -53,9 +54,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         max_position_embeddings=514,
         type_vocab_size=1,
     )
+    if classification_head:
+        config.num_labels = roberta.args.num_classes
     print("Our BERT config:", config)

-    model = RobertaForMaskedLM(config)
+    model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
     model.eval()

     # Now let's copy all the weights.
@@ -117,13 +120,19 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
         #### end of layer

-    # LM Head
-    model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
-    model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
-    model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
-    model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
-    model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
-    model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-    model.lm_head.bias = roberta.model.decoder.lm_head.bias
+    if classification_head:
+        model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
+        model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
+        model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
+        model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
+    else:
+        # LM Head
+        model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
+        model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
+        model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
+        model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
+        model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
+        model.lm_head.weight = roberta.model.decoder.lm_head.weight
+        model.lm_head.bias = roberta.model.decoder.lm_head.bias

     # Let's check that we get the same results.
@@ -157,8 +166,13 @@ if __name__ == "__main__":
                         type=str,
                         required=True,
                         help="Path to the output PyTorch model.")
+    parser.add_argument("--classification_head",
+                        action="store_true",
+                        help="Whether to convert a final classification head.")
     args = parser.parse_args()
     convert_roberta_checkpoint_to_pytorch(
         args.roberta_checkpoint_path,
-        args.pytorch_dump_folder_path
+        args.pytorch_dump_folder_path,
+        args.classification_head
     )
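For reference, a minimal sketch of how the updated converter could be driven once this change is in place. Only the function signature and the --classification_head flag come from the diff above; the checkpoint and output paths are hypothetical placeholders, the command-line flag names are inferred from the argparse attributes, and running the script assumes a local fairseq install with a downloaded RoBERTa checkpoint.

# Hypothetical driver for the converter changed above. Paths are placeholders;
# only the function signature and the classification_head argument are taken
# from this commit, and the underlying script needs fairseq installed.
from pytorch_transformers.convert_roberta_checkpoint_to_pytorch import (
    convert_roberta_checkpoint_to_pytorch,
)

convert_roberta_checkpoint_to_pytorch(
    roberta_checkpoint_path="/path/to/fairseq/roberta.large.mnli",  # placeholder
    pytorch_dump_folder_path="/path/to/output/roberta-large-mnli",  # placeholder
    classification_head=True,  # copy the fairseq 'mnli' head into RobertaForSequenceClassification
)

# Equivalent command-line call (flag names inferred from the argparse block above):
#   python convert_roberta_checkpoint_to_pytorch.py \
#       --roberta_checkpoint_path /path/to/fairseq/roberta.large.mnli \
#       --pytorch_dump_folder_path /path/to/output/roberta-large-mnli \
#       --classification_head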
pytorch_transformers/modeling_roberta.py
@@ -142,3 +142,60 @@ class RobertaLMHead(nn.Module):
         x = self.decoder(x) + self.bias

         return x
+
+
+class RobertaForSequenceClassification(BertPreTrainedModel):
+    """
+    Roberta Model with a classifier head on top.
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
+    def __init__(self, config):
+        super(RobertaForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config)
+        self.classifier = RobertaClassificationHead(config)
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None):
+        outputs = self.roberta(input_ids,
+                               position_ids=position_ids,
+                               token_type_ids=token_type_ids,
+                               attention_mask=attention_mask,
+                               head_mask=head_mask)
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super(RobertaClassificationHead, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
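As a usage note, here is a minimal sketch of calling the new class, based only on the forward() signature and return layout shown above. It reuses the roberta-large-mnli checkpoint and input ids from the integration test added in this commit; the label value is an arbitrary illustration.

# Minimal usage sketch for RobertaForSequenceClassification. The checkpoint
# name and input ids mirror the integration test in this commit; the label is
# an arbitrary class index used only to show the (loss, logits, ...) layout.
import torch

from pytorch_transformers.modeling_roberta import RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
model.eval()

input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
labels = torch.tensor([1])  # arbitrary label, only for illustration

logits = model(input_ids)[0]                        # shape (1, 3): one score per MNLI class
loss, logits = model(input_ids, labels=labels)[:2]  # loss is prepended when labels are given

Note the design choice: RobertaClassificationHead pools the first <s> token (equivalent to [CLS]) with its own dense/out_proj layers rather than reusing BERT's pooler, which is why the converter above copies classification_heads['mnli'].dense and .out_proj instead of pooler weights.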
pytorch_transformers/tests/modeling_roberta_test.py
@@ -179,5 +179,63 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
             shutil.rmtree(cache_dir)

             self.assertIsNotNone(model)
+
+
+class RobertaModelIntegrationTest(unittest.TestCase):
+
+    @pytest.mark.slow
+    def test_inference_masked_lm(self):
+        model = RobertaForMaskedLM.from_pretrained('roberta-base')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size((1, 11, 50265))
+        self.assertEqual(output.shape, expected_shape)
+        # compare the actual values for a slice.
+        expected_slice = torch.Tensor(
+            [[[33.8843, -4.3107, 22.7779],
+              [4.6533, -2.8099, 13.6252],
+              [1.8222, -3.6898, 8.8600]]]
+        )
+        self.assertTrue(
+            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
+        )
+
+    @pytest.mark.slow
+    def test_inference_no_head(self):
+        model = RobertaModel.from_pretrained('roberta-base')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        # compare the actual values for a slice.
+        expected_slice = torch.Tensor(
+            [[[-0.0231, 0.0782, 0.0074],
+              [-0.1854, 0.0539, -0.0174],
+              [0.0548, 0.0799, 0.1687]]]
+        )
+        self.assertTrue(
+            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
+        )
+
+    @pytest.mark.slow
+    def test_inference_classification_head(self):
+        model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size((1, 3))
+        self.assertEqual(output.shape, expected_shape)
+        expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
+        self.assertTrue(
+            torch.allclose(output, expected_tensor, atol=1e-3)
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
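For completeness, a standalone sketch that reproduces the masked-LM integration check outside the unittest harness. The input ids, expected slice, and tolerance are copied verbatim from the test above; downloading the 'roberta-base' weights is assumed to work exactly as in the test.

# Standalone reproduction of test_inference_masked_lm above. All reference
# values are copied from the test; only the final print line is new.
import torch

from pytorch_transformers.modeling_roberta import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained('roberta-base')
model.eval()

input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
with torch.no_grad():
    output = model(input_ids)[0]  # vocabulary logits, shape (1, 11, 50265)

expected_slice = torch.Tensor(
    [[[33.8843, -4.3107, 22.7779],
      [4.6533, -2.8099, 13.6252],
      [1.8222, -3.6898, 8.8600]]]
)
assert output.shape == torch.Size((1, 11, 50265))
assert torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
print("roberta-base masked-LM outputs match the reference slice")

Since the new integration tests are marked @pytest.mark.slow, they are presumably deselected in routine runs and only exercised when slow tests are explicitly requested in the project's pytest setup.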