chenpangpang / transformers · Commit cb9db101

Authored Aug 04, 2019 by Julien Chaumond
Parent: 05c08352

    Python 2 must DIE

Showing 3 changed files, with 12 additions and 8 deletions:

    pytorch_transformers/modeling_roberta.py                   +3  -3
    pytorch_transformers/tests/tokenization_roberta_test.py    +6  -4
    pytorch_transformers/tokenization_roberta.py               +3  -1
pytorch_transformers/modeling_roberta.py (+3 -3)

@@ -58,7 +58,7 @@ class RobertaEmbeddings(BertEmbeddings):
         # cf. fairseq's `utils.make_positions`
         position_ids = torch.arange(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=torch.long, device=input_ids.device)
         position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
-        return super().forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
+        return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)


 class RobertaConfig(BertConfig):

@@ -109,8 +109,8 @@ class RobertaForMaskedLM(BertPreTrainedModel):
 class RobertaLMHead(nn.Module):
     """Roberta Head for masked language modeling."""

-    def __init__(self, config: BertConfig):
-        super().__init__()
+    def __init__(self, config):
+        super(RobertaLMHead, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
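Both hunks above replace Python 3-only syntax with spellings that also run under Python 2: the parameter annotation `config: BertConfig` is a SyntaxError on Python 2, and the zero-argument `super()` call raises a TypeError there at runtime, whereas the explicit `super(Class, self)` form works on both interpreters. A minimal sketch of the two `super()` spellings, using hypothetical classes rather than the repo's own:

# Minimal sketch (hypothetical Base/Child, not classes from the repo)
# contrasting the two super() spellings this commit touches.

class Base(object):
    def forward(self, x):
        return x + 1


class Child(Base):
    def forward(self, x):
        # Python 3 only (TypeError on Python 2):
        #     return super().forward(x)
        # Works on both Python 2 and 3 (the form the diff switches to):
        return super(Child, self).forward(x)


print(Child().forward(1))  # prints 2

On Python 3 the two forms are equivalent; the explicit two-argument form is simply the only one Python 2 accepts.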
pytorch_transformers/tests/tokenization_roberta_test.py (+6 -4)

@@ -18,6 +18,7 @@ from __future__ import (absolute_import, division, print_function,
 import os
 import unittest
 import pytest
+import six

 from pytorch_transformers.tokenization_roberta import RobertaTokenizer

@@ -31,6 +32,7 @@ class RobertaTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.encode('Hello world!'), [0, 31414, 232, 328, 2])
-        self.assertListEqual(
-            tokenizer.encode('Hello world! cécé herlolip'),
-            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
+        if six.PY3:
+            self.assertListEqual(
+                tokenizer.encode('Hello world! cécé herlolip'),
+                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
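The new `if six.PY3:` guard restricts the non-ASCII `'Hello world! cécé herlolip'` assertion to Python 3, where string literals are unicode by default. A minimal sketch of the same version-gating pattern with `six` (the body is illustrative, not the repo's test):

# Minimal sketch of gating a check on the interpreter version with six
# (illustrative body, not the repo's test). six.PY3 is True on Python 3;
# sys.version_info[0] == 3 is the stdlib equivalent.
import six

def check_unicode_handling():
    if six.PY3:
        text = 'cécé'  # str literals are unicode on Python 3
        assert text.encode('utf-8').decode('utf-8') == text

check_unicode_handling()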
pytorch_transformers/tokenization_roberta.py (+3 -1)

@@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function,
 import json
 import logging
 import re
+from io import open
+import six

 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer

@@ -125,7 +127,7 @@ class Dictionary(object):
         Loads a pre-existing dictionary from a text file and adds its symbols
         to this instance.
         """
-        if isinstance(f, str):
+        if isinstance(f, six.string_types):
             try:
                 if not ignore_utf_errors:
                     with open(f, 'r', encoding='utf-8') as fd:
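Two compatibility shims appear here: `six.string_types` is `(basestring,)` on Python 2 and `(str,)` on Python 3, so the `isinstance` check accepts both `str` and `unicode` paths, and `from io import open` gives Python 2 an `open()` that understands the `encoding=` keyword used just below it. A minimal sketch of the `isinstance` pattern with a hypothetical helper:

# Minimal sketch of the six.string_types check (hypothetical helper,
# not the repo's Dictionary class). On Python 2, six.string_types is
# (basestring,), so str and unicode paths both pass; on Python 3 it
# is simply (str,).
import io
import six

def describe(f):
    if isinstance(f, six.string_types):
        return 'path'      # caller passed a filename
    return 'file object'   # caller passed an already-open handle

print(describe('dict.txt'))            # path
print(describe(io.StringIO(u'word')))  # file object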